diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d7228dd..4410abd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,11 +21,11 @@ jobs: strategy: fail-fast: false matrix: - rust: [ "stable", "beta", "nightly", "1.65" ] # MSRV + rust: [ "stable", "beta", "nightly", "1.80" ] # MSRV flags: [ "--no-default-features", "", "--all-features" ] exclude: # Skip because some features have highest MSRV. - - rust: "1.65" # MSRV + - rust: "1.80" # MSRV flags: "--all-features" steps: - uses: actions/checkout@v3 @@ -37,10 +37,10 @@ jobs: cache-on-failure: true # Only run tests on the latest stable and above - name: check - if: ${{ matrix.rust == '1.65' }} # MSRV + if: ${{ matrix.rust == '1.80' }} # MSRV run: cargo check --workspace ${{ matrix.flags }} - name: test - if: ${{ matrix.rust != '1.65' }} # MSRV + if: ${{ matrix.rust != '1.80' }} # MSRV run: cargo test --workspace ${{ matrix.flags }} coverage: diff --git a/Cargo.toml b/Cargo.toml index db6b24d..d2da35f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,79 +1,61 @@ [package] -authors = ["Programatik ", "Megsdevs "] +authors = ["0xAlcibiades "] categories = ["asynchronous", "network-programming", "web-programming"] description = "High level server for hyper and tower." edition = "2021" -homepage = "https://github.com/valorem-labs-inc/hyper-server" -keywords = ["axum", "tonic", "hyper", "tower", "server"] +homepage = "https://github.com/warlock-labs/hyper-server" +keywords = ["tcp", "tls", "http", "hyper", "tokio"] license = "MIT" name = "hyper-server" readme = "README.md" repository = "https://github.com/valorem-labs-inc/hyper-server" -version = "0.5.3" +version = "0.7.0" -[features] -default = [] -tls-rustls = ["arc-swap", "pin-project-lite", "rustls", "rustls-pemfile", "tokio/fs", "tokio/time", "tokio-rustls"] -tls-openssl = ["openssl", "tokio-openssl", "pin-project-lite"] -proxy-protocol = ["ppp", "pin-project-lite"] +[target.'cfg(not(target_env = "msvc"))'.dependencies] +tikv-jemallocator = { version = "0.6", optional = true } [dependencies] - -# optional dependencies -## rustls -arc-swap = { version = "1", optional = true } -bytes = "1" -futures-util = { version = "0.3", default-features = false, features = ["alloc"] } -http = "0.2" -http-body = "0.4" -hyper = { version = "0.14.27", features = ["http1", "http2", "server", "runtime"] } - -## openssl -openssl = { version = "0.10", optional = true } -pin-project-lite = { version = "0.2", optional = true } -rustls = { version = "0.21", features = ["dangerous_configuration"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tokio = { version = "1", features = ["macros", "net", "sync"] } -tokio-openssl = { version = "0.6", optional = true } -tokio-rustls = { version = "0.24", optional = true } -tower-service = "0.3" - -## proxy-protocol -ppp = { version = "2.2.0", optional = true } +async-stream = "0.3.5" +bytes = "1.7.1" +futures = "0.3.30" +http = "1.1.0" +http-body = "1.0.1" +http-body-util = "0.1.2" +hyper = "1.4.1" +hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service", "http2"] } +pin-project = "1.1.5" +pprof = { version = "0.13.0", features = ["flamegraph"], optional = true } +ring = "0.17.8" +rustls = { version = "0.23.13", features = ["zlib"] } +rustls-pemfile = "2.1.3" +tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread", "time"] } +tokio-rustls = "0.26.0" +tokio-stream = { version = "0.1.16", features = ["net"] } +tokio-util = "0.7.12" +tower = { version = "0.5.1", features = ["util"] } +tracing = "0.1.40" [dev-dependencies] -axum = "0.6" -hyper = { version = "0.14", features = ["full"] } -tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.4.4", features = ["add-extension"] } - -[package.metadata.docs.rs] -all-features = true -cargo-args = ["-Zunstable-options", "-Zrustdoc-scrape-examples"] -rustdoc-args = ["--cfg", "docsrs"] +criterion = { version = "0.5.1", features = ["html_reports", "async_tokio"] } +hyper = { version = "1.4.1", features = ["client"] } +hyper-rustls = { version = "0.27.3", features = ["http1", "http2"] } +hyper-util = { version = "0.1.8", features = ["client", "client-legacy", "http2"] } +num_cpus = "1.16.0" +rcgen = "0.13.1" +reqwest = { version = "0.12.7", features = ["rustls-tls", "http2"] } +ring = "0.17.8" +tokio = { version = "1.40", features = ["rt-multi-thread", "net", "test-util", "time"] } +tokio-util = { version = "0.7.12", features = ["compat"] } +tracing-subscriber = "0.3.18" + +[[bench]] +name = "hello_world_tower_hyper_tls_tcp" +harness = false -[[example]] -name = "from_std_listener_rustls" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "http_and_https" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "rustls_reload" -required-features = ["tls-rustls"] -doc-scrape-examples = true - -[[example]] -name = "rustls_server" -required-features = ["tls-rustls"] -doc-scrape-examples = true +[features] +default = [] +jemalloc = ["tikv-jemallocator"] +dev-profiling = ["pprof"] -[[example]] -name = "rustls_session" -required-features = ["tls-rustls"] -doc-scrape-examples = true +[profile.release] +debug = true diff --git a/README.md b/README.md index 5d0b234..44b7e14 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,237 @@ +# hyper-server + [![License](https://img.shields.io/crates/l/hyper-server)](https://choosealicense.com/licenses/mit/) [![Crates.io](https://img.shields.io/crates/v/hyper-server)](https://crates.io/crates/hyper-server) [![Docs](https://img.shields.io/crates/v/hyper-server?color=blue&label=docs)](https://docs.rs/hyper-server/) -![CI](https://github.com/valorem-labs-inc/hyper-server/actions/workflows/CI.yml/badge.svg) -[![codecov](https://codecov.io/gh/valorem-labs-inc/hyper-server/branch/master/graph/badge.svg?token=8W5MEJQSW6)](https://codecov.io/gh/valorem-labs-inc/hyper-server) - -# hyper-server +![CI](https://github.com/warlock-labs/hyper-server/actions/workflows/CI.yml/badge.svg) +[![codecov](https://codecov.io/gh/warlock-labs/hyper-server/branch/master/graph/badge.svg?token=8W5MEJQSW6)](https://codecov.io/gh/warlock-labs/hyper-server) -hyper-server is a high performance [hyper] server implementation designed to -work with [axum], [tonic] and [tower]. +A high-performance, modular server implementation built on [hyper], designed to +work seamlessly with [axum], [tonic], [tower], and other tower-compatible +frameworks. ## Features -- HTTP/1 and HTTP/2 -- HTTPS through [rustls] and openssl. -- High performance through [hyper]. -- Using [tower] make service API. -- Exceptional [axum] compatibility. Likely to work with future [axum] releases. -- Superb [tonic] compatibility. Likely to work with future [tonic] releases. +- HTTP/1 and HTTP/2 support +- TLS/HTTPS through [rustls] +- High performance leveraging [hyper] 1.0 +- Modular architecture based on `tokio::net::TcpListener`, `tokio-rustls::Acceptor`, and `hyper::server::conn::auto` +- Flexible integration with [tower] the ecosystem, supporting various backends: + - [axum] + - [tonic] + - Any `hyper::service::Service`, `tower::Service`, or `tower::Layer` composition + +## Installation + +Add this to your `Cargo.toml`: + +```toml +[dependencies] +hyper-server = "0.7.0" +``` -## Usage Example +## Usage -A simple hello world application can be served like: +Here's an example of how to use hyper-server with a simple tower lambda service via TCP/TLS/HTTP2 transport: ```rust -use axum::{routing::get, Router}; +use std::convert::Infallible; use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::{Request, Response}; +use hyper::body::Incoming; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tower::{Layer, ServiceBuilder}; +use tracing::{trace, debug, info}; + +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +// Define a simple service that responds with "Hello, World!" +async fn hello(_: Request) -> Result>, Infallible> { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +} + +// Define a Custom middleware to add a header to all responses, for example +struct AddHeaderLayer; + +impl Layer for AddHeaderLayer { + type Service = AddHeaderService; + + fn layer(&self, service: S) -> Self::Service { + AddHeaderService { inner: service } + } +} + +#[derive(Clone)] +struct AddHeaderService { + inner: S, +} -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); +impl tower::Service> for AddHeaderService +where + S: tower::Service, Response=Response>>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + trace!("Adding custom header to response"); + let future = self.inner.call(req); + Box::pin(async move { + let mut resp = future.await?; + resp.headers_mut() + .insert("X-Custom-Header", "Hello from middleware!".parse().unwrap()); + Ok(resp) + }) + } +} - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); + // 1. Set up the TCP listener + let listener = TcpListener::bind(addr).await?; + info!("Listening on https://{}", addr); + let incoming = TcpListenerStream::new(listener); + + // 2. Create the HTTP connection builder + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + + // 3. Set up the Tower service with middleware + let svc = tower::service_fn(hello); + let svc = ServiceBuilder::new() + .layer(AddHeaderLayer) // Custom middleware + .service(svc); + + // 4. Convert the Tower service to a Hyper service + let svc = TowerToHyperService::new(svc); + + // 5. Set up TLS config + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + let tls_config = Arc::new(config); + + // 6. Set up graceful shutdown + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn a task to send the shutdown signal after 1 second + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + let _ = shutdown_tx.send(()); + debug!("Shutdown signal sent"); + }); + + // 7. Start the server + info!("Starting HTTPS server..."); + serve_http_with_shutdown( + svc, + incoming, + builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + info!("Shutdown signal received, starting graceful shutdown"); + }), + ) + .await?; + + info!("Server has shut down"); + // Et voilà! + // A flexible, high-performance server with custom services, middleware, + // http, tls, tcp, and graceful shutdown + Ok(()) } ``` -You can find more examples [here](/examples). +For more advanced usage and examples, please refer to, or contribute to, +the [examples directory](/examples). + +## Architecture + +hyper-server provides a layered, composable architecture: + +1. TCP Listening: `tokio::net::TcpListener` +2. TLS (optional): `rustls::TlsAcceptor` +3. HTTP: `hyper_util::server::conn::auto` +4. Service: `hyper::service::Service` +5. Middleware: `tower::Service` +6. Application: Your choice of tower-compatible framework (axum, tonic, etc.) or custom service implementation + +This structure allows for easy customization and extension at each layer. You +can integrate your own implementations at any level of the stack, providing +maximum flexibility for your specific use case. + +## Security + +hyper-server takes security seriously. We use `rustls` for TLS support, which +provides modern, secure defaults. However, please ensure that you configure +your server appropriately for your use case, especially when deploying in +production environments. + +If you discover any security-related issues, please email team@warlock.xyz +instead of using the issue tracker. + +## API + +For detailed API documentation, please refer to the [API docs on docs.rs](https://docs.rs/hyper-server/). ## Minimum Supported Rust Version -hyper-server's MSRV is `1.65`. +hyper-server's MSRV is `1.80`. -## Safety +## Contributing -This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. +We welcome contributions to hyper-server! Our contributing guidelines are +inspired by the Rule of St. Benedict, emphasizing humility, listening, +and community. Before contributing, please familiarize yourself with these +principles at [The Rule of St. Benedict](http://www.benedictfriend.org/the-rule.html). -## License +Key points for contributors: -This project is licensed under the [MIT license](LICENSE). +- Listen first, speak second (Chapter 6) +- Be humble in your contributions (Chapter 7) +- Work diligently and carefully (Chapter 48) +- Treat all code and ideas with respect (Chapter 72) -## Why fork +## License -This project is based on the great work in [axum-server], which is no longer actively maintained. -The rationale for forking is that we use this for critical infrastructure and want to be able to -extend the crate and fix bugs as needed. +This project is licensed under the MIT License — see +the [LICENSE](/LICENSE) file for details. -[axum-server]: https://github.com/programatik29/axum-server [axum]: https://crates.io/crates/axum + [hyper]: https://crates.io/crates/hyper + [rustls]: https://crates.io/crates/rustls + [tower]: https://crates.io/crates/tower + [tonic]: https://crates.io/crates/tonic + +[tungstenite]: https://crates.io/crates/tungstenite \ No newline at end of file diff --git a/benches/hello_world_tower_hyper_tls_tcp.rs b/benches/hello_world_tower_hyper_tls_tcp.rs new file mode 100644 index 0000000..a6f3785 --- /dev/null +++ b/benches/hello_world_tower_hyper_tls_tcp.rs @@ -0,0 +1,436 @@ +//! Hello World Benchmark for hyper-server +//! +//! This module implements a comprehensive benchmark for the hyper-server crate, +//! testing its performance in various scenarios including latency, throughput, +//! and concurrent requests. +//! +//! It uses a very basic echo service that responds with "Hello, World!" to GET requests +//! and echoes back the request body for POST requests. The server is configured with +//! an optimized ECDSA certificate and various TLS performance improvements. +//! It exercises the full stack from Socket → TCP → TLS → HTTP/2 → hyper-server → tower-service. +//! This allows developers of the library to optimize the full stack for performance. +//! The library provides a detailed benchmark report with latency, throughput, and +//! concurrency stress tests. +//! It additionally has provision to generate flamegraphs for each benchmark run. +//! +//! For developers who use hyper-server, this provides a good starting point to +//! understand the performance of the library +//! and how to use it optimally in their applications. + +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use bytes::Bytes; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use http::{Request, Response, StatusCode, Uri}; +use http_body_util::{BodyExt, Empty, Full}; +use hyper::body::Incoming; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::client::legacy::Client; +use hyper_util::rt::{TokioExecutor, TokioTimer}; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rcgen::{CertificateParams, DistinguishedName, KeyPair}; +use rustls::crypto::aws_lc_rs::Ticketer; +use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; +use rustls::server::ServerSessionMemoryCache; +use rustls::{ClientConfig, RootCertStore, ServerConfig}; +use tokio::net::TcpSocket; +use tokio::runtime::Runtime; +use tokio::sync::oneshot; +use tokio::time::{Duration, Instant}; +use tokio_stream::wrappers::TcpListenerStream; +use tracing::info; + +use hyper_server::serve_http_with_shutdown; + +/// Profiling module for generating flamegraphs during benchmarks +/// +/// This module is only compiled when the "dev-profiling" feature is enabled. +/// It provides a custom profiler that integrates with Criterion to generate +/// flamegraphs for each benchmark run. +#[cfg(feature = "dev-profiling")] +mod profiling { + use std::fs::File; + use std::path::Path; + + use criterion::profiler::Profiler; + use pprof::ProfilerGuard; + + /// Custom profiler for generating flamegraphs + /// + /// This struct implements the `Profiler` trait from Criterion, + /// allowing it to be used as a custom profiler in benchmark runs. + pub struct FlamegraphProfiler<'a> { + /// Sampling frequency for the profiler (in Hz) + frequency: i32, + /// The active profiler instance, if profiling is currently in progress + active_profiler: Option>, + } + + impl<'a> FlamegraphProfiler<'a> { + /// Creates a new `FlamegraphProfiler` instance + /// + /// # Arguments + /// + /// * `frequency` - The sampling frequency for the profiler, in Hz + /// + /// # Returns + /// + /// A new `FlamegraphProfiler` instance + pub fn new(frequency: i32) -> Self { + FlamegraphProfiler { + frequency, + active_profiler: None, + } + } + } + + impl<'a> Profiler for FlamegraphProfiler<'a> { + /// Starts profiling for a benchmark + /// + /// This method is called by Criterion at the start of each benchmark iteration. + /// It creates a new `ProfilerGuard` instance and stores it in `active_profiler`. + /// + /// # Arguments + /// + /// * `_benchmark_id` - The ID of the benchmark (unused in this implementation) + /// * `_benchmark_dir` - The directory for benchmark results (unused in this implementation) + fn start_profiling(&mut self, _benchmark_id: &str, _benchmark_dir: &Path) { + self.active_profiler = Some(ProfilerGuard::new(self.frequency).unwrap()); + } + + /// Stops profiling and generates a flamegraph + /// + /// This method is called by Criterion at the end of each benchmark iteration. + /// It generates a flamegraph from the collected profile data and saves it as an SVG file. + /// + /// # Arguments + /// + /// * `_benchmark_id` - The ID of the benchmark (unused in this implementation) + /// * `benchmark_dir` - The directory where the flamegraph should be saved + fn stop_profiling(&mut self, _benchmark_id: &str, benchmark_dir: &Path) { + // Ensure the benchmark directory exists + std::fs::create_dir_all(benchmark_dir).unwrap(); + + // Define the path for the flamegraph SVG file + let flamegraph_path = benchmark_dir.join("flamegraph.svg"); + + // Create the flamegraph file + let flamegraph_file = File::create(&flamegraph_path) + .expect("File system error while creating flamegraph.svg"); + + // Generate and write the flamegraph if a profiler is active + if let Some(profiler) = self.active_profiler.take() { + profiler + .report() + .build() + .unwrap() + .flamegraph(flamegraph_file) + .expect("Error writing flamegraph"); + } + } + } +} + +/// Holds the TLS configuration for both server and client +struct TlsConfig { + server_config: ServerConfig, + client_config: ClientConfig, +} + +/// Generates a shared TLS configuration for both server and client +/// +/// This function creates a self-signed ECDSA certificate and configures both +/// the server and client to use it. It also applies various optimizations +/// to improve TLS performance. +fn generate_shared_ecdsa_config() -> TlsConfig { + // Generate ECDSA key pair + let key_pair = KeyPair::generate().expect("Failed to generate key pair"); + + // Generate certificate parameters + let mut params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("Failed to create certificate params"); + params.distinguished_name = DistinguishedName::new(); + + // Generate the self-signed certificate + let cert = params + .self_signed(&key_pair) + .expect("Failed to generate self-signed certificate"); + + // Serialize the certificate and private key + let cert_der = cert.der().to_vec(); + let key_der = key_pair.serialize_der(); + + // Create Rustls certificate and private key + let cert = CertificateDer::from(cert_der); + let key = PrivatePkcs8KeyDer::from(key_der); + + // Configure Server + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert.clone()], key.into()) + .expect("Failed to configure server"); + + // Server optimizations + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + server_config.max_fragment_size = Some(16384); + server_config.send_tls13_tickets = 8; // Enable 0.5-RTT data + server_config.session_storage = ServerSessionMemoryCache::new(10240); + server_config.ticketer = Ticketer::new().unwrap(); + server_config.max_early_data_size = 16384; // Enable 0-RTT data + + // Configure Client + let mut root_store = RootCertStore::empty(); + root_store + .add(cert) + .expect("Failed to add certificate to root store"); + + let mut client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + // Client optimizations + client_config.enable_sni = false; // Since we're using localhost + client_config.max_fragment_size = Some(16384); + client_config.enable_early_data = true; // Enable 0-RTT data + client_config.resumption = rustls::client::Resumption::in_memory_sessions(10240); + + TlsConfig { + server_config, + client_config, + } +} + +fn create_optimized_runtime(thread_count: usize) -> io::Result { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(thread_count) + .max_blocking_threads(thread_count * 2) + .enable_all() + .build() +} + +#[inline] +async fn echo(req: Request) -> Result>, hyper::Error> { + match (req.method(), req.uri().path()) { + (&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))), + (&hyper::Method::POST, "/echo") => { + let body = req.collect().await?.to_bytes(); + Ok(Response::new(Full::new(body))) + } + _ => { + let mut res = Response::new(Full::new(Bytes::from("Not Found"))); + *res.status_mut() = StatusCode::NOT_FOUND; + Ok(res) + } + } +} + +async fn setup_server( +) -> Result<(TcpListenerStream, SocketAddr), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let socket = TcpSocket::new_v4()?; + + // Optimize TCP parameters + socket.set_send_buffer_size(262_144)?; // 256 KB + socket.set_recv_buffer_size(262_144)?; // 256 KB + socket.set_nodelay(true)?; + socket.set_reuseaddr(true)?; + socket.set_reuseport(true)?; + socket.set_keepalive(true)?; + + socket.bind(addr)?; + let listener = socket.listen(8192)?; // Increased backlog for high-traffic scenarios + + let server_addr = listener.local_addr()?; + let incoming = TcpListenerStream::new(listener); + + Ok((incoming, server_addr)) +} + +async fn start_server( + tls_config: ServerConfig, +) -> Result<(SocketAddr, oneshot::Sender<()>), Box> { + let tls_config = Arc::new(tls_config); + let (incoming, server_addr) = setup_server().await?; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + tokio::spawn(async move { + serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + ) + .await + .unwrap(); + }); + Ok((server_addr, shutdown_tx)) +} + +#[inline] +async fn send_request( + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, + url: Uri, +) -> Result<(Duration, usize), Box> { + let start = Instant::now(); + let res = client.get(url).await?; + assert_eq!(res.status(), StatusCode::OK); + let body = res.into_body().collect().await?.to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + Ok((start.elapsed(), body.len())) +} + +#[inline] +async fn concurrent_benchmark( + client: &Client< + hyper_rustls::HttpsConnector, + Empty, + >, + url: Uri, + num_requests: usize, +) -> (Duration, Vec, usize) { + let start = Instant::now(); + let mut handles = Vec::with_capacity(num_requests); + + for _ in 0..num_requests { + let client = client.clone(); + let url = url.clone(); + let handle = tokio::spawn(async move { send_request(&client, url).await }); + handles.push(handle); + } + + let mut request_times = Vec::with_capacity(num_requests); + let mut total_bytes = 0; + + for handle in handles { + if let Ok(Ok((duration, bytes))) = handle.await { + request_times.push(duration); + total_bytes += bytes; + } + } + + let total_time = start.elapsed(); + (total_time, request_times, total_bytes) +} + +fn bench_server(c: &mut Criterion) { + let server_runtime = Arc::new(create_optimized_runtime(num_cpus::get() / 2).unwrap()); + + let (server_addr, shutdown_tx, client) = server_runtime.block_on(async { + // Setup default rustls crypto provider + rustls::crypto::aws_lc_rs::default_provider() + .install_default() + .expect("Failed to install rustls crypto provider"); + let tls_config = generate_shared_ecdsa_config(); + let (server_addr, shutdown_tx) = start_server(tls_config.server_config.clone()) + .await + .expect("Failed to start server"); + info!("Server started on {}", server_addr); + + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config.client_config) + .https_or_http() + .enable_all_versions() + .build(); + + let client = Client::builder(TokioExecutor::new()) + // HTTP/2 settings + .http2_only(true) // Force HTTP/2 for consistent benchmarking and to match server config + .http2_initial_stream_window_size(2 * 1024 * 1024) // 2MB, matches server setting for better flow control + .http2_initial_connection_window_size(4 * 1024 * 1024) // 4MB, matches server setting for improved throughput + .http2_adaptive_window(true) // Enable dynamic flow control to optimize performance under varying conditions + .http2_max_frame_size(32 * 1024) // 32KB, matches server setting for larger data chunks + .http2_keep_alive_interval(Duration::from_secs(10)) // Maintain connection health, matching server's 10-second interval + .http2_keep_alive_timeout(Duration::from_secs(30)) // Allow time for keep-alive response, matching server's 30-second timeout + .http2_max_concurrent_reset_streams(2000) // Match server's max concurrent streams for better parallelism + .http2_max_send_buf_size(2 * 1024 * 1024) // 2MB, matches server setting for improved write performance + // Connection pooling settings + .pool_idle_timeout(Duration::from_secs(90)) // Keep connections alive longer for reuse in benchmarks + .pool_max_idle_per_host(2000) // Match max concurrent streams to fully utilize HTTP/2 multiplexing + .timer(TokioTimer::new()) + .pool_timer(TokioTimer::new()) + .build(https); + + (server_addr, shutdown_tx, client) + }); + + let url = Uri::builder() + .scheme("https") + .authority(format!("localhost:{}", server_addr.port())) + .path_and_query("/") + .build() + .expect("Failed to build URI"); + + let mut group = c.benchmark_group("hyper_server"); + group.sample_size(20); + group.measurement_time(Duration::from_secs(30)); + + // Latency + group.throughput(Throughput::Elements(1)); + group.bench_function("serial_latency", |b| { + let client = client.clone(); + let url = url.clone(); + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime) + .iter(|| async { send_request(&client, url.clone()).await.unwrap() }); + }); + + // Concurrency stress test + let concurrent_requests = vec![10, 50, 250, 1250]; + for &num_requests in &concurrent_requests { + group.throughput(Throughput::Elements(num_requests as u64)); + group.bench_with_input( + BenchmarkId::new("concurrent_latency", num_requests), + &num_requests, + |b, &num_requests| { + let client = client.clone(); + let url = url.clone(); + let client_runtime = create_optimized_runtime(num_cpus::get() / 2).unwrap(); + b.to_async(client_runtime).iter(|| async { + concurrent_benchmark(&client, url.clone(), num_requests).await + }); + }, + ); + } + + group.finish(); + + server_runtime.block_on(async { + shutdown_tx.send(()).unwrap(); + tokio::time::sleep(Duration::from_secs(1)).await; + }); +} + +#[cfg(not(feature = "dev-profiling"))] +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(10) + .measurement_time(Duration::from_secs(30)) + .warm_up_time(Duration::from_secs(5)); + targets = bench_server +} + +#[cfg(feature = "dev-profiling")] +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(10) + .measurement_time(Duration::from_secs(30)) + .warm_up_time(Duration::from_secs(5)) + .with_profiler(profiling::FlamegraphProfiler::new(100)); + targets = bench_server +} + +criterion_main!(benches); diff --git a/examples/configure_addr_incoming.rs b/examples/configure_addr_incoming.rs deleted file mode 100644 index bf8ad58..0000000 --- a/examples/configure_addr_incoming.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Run with `cargo run --example configure_http` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::AddrIncomingConfig; -use std::net::SocketAddr; -use std::time::Duration; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = AddrIncomingConfig::new() - .tcp_nodelay(true) - .tcp_sleep_on_accept_errors(true) - .tcp_keepalive(Some(Duration::from_secs(32))) - .tcp_keepalive_interval(Some(Duration::from_secs(1))) - .tcp_keepalive_retries(Some(1)) - .build(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .addr_incoming_config(config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/configure_http.rs b/examples/configure_http.rs deleted file mode 100644 index cfd79e2..0000000 --- a/examples/configure_http.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Run with `cargo run --example configure_http` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::HttpConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = HttpConfig::new() - .http1_only(true) - .http2_only(false) - .max_buf_size(8192) - .build(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .http_config(config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/from_std_listener.rs b/examples/from_std_listener.rs deleted file mode 100644 index 3360187..0000000 --- a/examples/from_std_listener.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! Run with `cargo run --example from_std_listener` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use std::net::{SocketAddr, TcpListener}; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = TcpListener::bind(addr).unwrap(); - println!("listening on {}", addr); - hyper_server::from_tcp(listener) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/from_std_listener_rustls.rs b/examples/from_std_listener_rustls.rs deleted file mode 100644 index 9011212..0000000 --- a/examples/from_std_listener_rustls.rs +++ /dev/null @@ -1,28 +0,0 @@ -//! Run with `cargo run --all-features --example from_std_listener_rustls` -//! command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::{SocketAddr, TcpListener}; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let listener = TcpListener::bind(addr).unwrap(); - println!("listening on {}", addr); - hyper_server::from_tcp_rustls(listener, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/full.rs b/examples/full.rs new file mode 100644 index 0000000..e2a9158 --- /dev/null +++ b/examples/full.rs @@ -0,0 +1,132 @@ +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use bytes::Bytes; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::{Request, Response}; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +use hyper_util::service::TowerToHyperService; +use rustls::ServerConfig; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; +use tower::{Layer, ServiceBuilder}; +use tracing::{debug, info, trace}; + +use hyper_server::{load_certs, load_private_key, serve_http_with_shutdown}; + +// Define a simple service that responds with "Hello, World!" +async fn hello(_: Request) -> Result>, Infallible> { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +} + +// Define a Custom middleware to add a header to all responses, for example +struct AddHeaderLayer; + +impl Layer for AddHeaderLayer { + type Service = AddHeaderService; + + fn layer(&self, service: S) -> Self::Service { + AddHeaderService { inner: service } + } +} + +#[derive(Clone)] +struct AddHeaderService { + inner: S, +} + +impl tower::Service> for AddHeaderService +where + S: tower::Service, Response = Response>>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + trace!("Adding custom header to response"); + let future = self.inner.call(req); + Box::pin(async move { + let mut resp = future.await?; + resp.headers_mut() + .insert("X-Custom-Header", "Hello from middleware!".parse().unwrap()); + Ok(resp) + }) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); + // 1. Set up the TCP listener + let listener = TcpListener::bind(addr).await?; + info!("Listening on https://{}", addr); + let incoming = TcpListenerStream::new(listener); + + // 2. Create the HTTP connection builder + let builder = HttpConnectionBuilder::new(TokioExecutor::new()); + + // 3. Set up the Tower service with middleware + let svc = tower::service_fn(hello); + let svc = ServiceBuilder::new() + .layer(AddHeaderLayer) // Custom middleware + .service(svc); + + // 4. Convert the Tower service to a Hyper service + let svc = TowerToHyperService::new(svc); + + // 5. Set up TLS config + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + let tls_config = Arc::new(config); + + // 6. Set up graceful shutdown + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn a task to send the shutdown signal after 1 second + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + let _ = shutdown_tx.send(()); + debug!("Shutdown signal sent"); + }); + + // 7. Start the server + info!("Starting HTTPS server..."); + serve_http_with_shutdown( + svc, + incoming, + builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + info!("Shutdown signal received, starting graceful shutdown"); + }), + ) + .await?; + + info!("Server has shut down"); + // Et voilà! + // A flexible, high-performance server with custom services, + // middleware, http, tls, tcp, and graceful shutdown + Ok(()) +} diff --git a/examples/graceful_shutdown.rs b/examples/graceful_shutdown.rs deleted file mode 100644 index 82ec31b..0000000 --- a/examples/graceful_shutdown.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Run with `cargo run --example graceful_shutdown` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. -//! -//! After 10 seconds: -//! - If there aren't any connections alive, server will shutdown. -//! - If there are connections alive, server will wait until deadline is elapsed. -//! - Deadline is 30 seconds. Server will shutdown anyways when deadline is elapsed. - -use axum::{routing::get, Router}; -use hyper_server::Handle; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let handle = Handle::new(); - - // Spawn a task to gracefully shutdown server. - tokio::spawn(graceful_shutdown(handle.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .handle(handle) - .serve(app.into_make_service()) - .await - .unwrap(); - - println!("server is shut down"); -} - -async fn graceful_shutdown(handle: Handle) { - // Wait 10 seconds. - sleep(Duration::from_secs(10)).await; - - println!("sending graceful shutdown signal"); - - // Signal the server to shutdown using Handle. - handle.graceful_shutdown(Some(Duration::from_secs(30))); - - // Print alive connection count every second. - loop { - sleep(Duration::from_secs(1)).await; - - println!("alive connections: {}", handle.connection_count()); - } -} diff --git a/examples/hello_world.rs b/examples/hello_world.rs deleted file mode 100644 index d91e25c..0000000 --- a/examples/hello_world.rs +++ /dev/null @@ -1,18 +0,0 @@ -//! Run with `cargo run --example hello_world` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{routing::get, Router}; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/http_and_https.rs b/examples/http_and_https.rs deleted file mode 100644 index 63b05db..0000000 --- a/examples/http_and_https.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Run with `cargo run --all-features --example http_and_https` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url which should redirect to -//! "https://localhost:3443". - -use axum::{http::uri::Uri, response::Redirect, routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let http = tokio::spawn(http_server()); - let https = tokio::spawn(https_server()); - - // Ignore errors. - let _ = tokio::join!(http, https); -} - -async fn http_server() { - let app = Router::new().route("/", get(http_handler)); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("http listening on {}", addr); - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -async fn http_handler(uri: Uri) -> Redirect { - let uri = format!("https://127.0.0.1:3443{}", uri.path()); - - Redirect::temporary(&uri) -} - -async fn https_server() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3443)); - println!("https listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/multiple_addresses.rs b/examples/multiple_addresses.rs deleted file mode 100644 index 8030586..0000000 --- a/examples/multiple_addresses.rs +++ /dev/null @@ -1,25 +0,0 @@ -use axum::{routing::get, Router}; -use futures_util::future::try_join_all; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; - -#[tokio::main] -async fn main() { - let servers = vec![ - SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 3000), - SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 3000), - ] - .into_iter() - .map(|addr| tokio::spawn(start_server(addr))); - - // Returns the first error if any of the servers return an error. - try_join_all(servers).await.unwrap(); -} - -async fn start_server(addr: SocketAddr) { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - hyper_server::bind(addr) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/remote_address.rs b/examples/remote_address.rs deleted file mode 100644 index fb29c7f..0000000 --- a/examples/remote_address.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! Run with `cargo run --example remote_address` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use axum::{extract::ConnectInfo, routing::get, Router}; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new() - .route("/", get(handler)) - .into_make_service_with_connect_info::(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - - hyper_server::bind(addr).serve(app).await.unwrap(); -} - -async fn handler(ConnectInfo(addr): ConnectInfo) -> String { - format!("your ip address is: {}", addr) -} diff --git a/examples/remote_address_using_tower.rs b/examples/remote_address_using_tower.rs deleted file mode 100644 index 6342963..0000000 --- a/examples/remote_address_using_tower.rs +++ /dev/null @@ -1,27 +0,0 @@ -//! Run with `cargo run --example remote_address_using_tower` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. - -use hyper::{server::conn::AddrStream, Body, Request, Response}; -use std::{convert::Infallible, net::SocketAddr}; -use tower::service_fn; -use tower_http::add_extension::AddExtension; - -#[tokio::main] -async fn main() { - let service = service_fn(|mut req: Request| async move { - let addr: SocketAddr = req.extensions_mut().remove().unwrap(); - let body = Body::from(format!("IP Address: {}", addr)); - - Ok::<_, Infallible>(Response::new(body)) - }); - - hyper_server::bind(SocketAddr::from(([127, 0, 0, 1], 3000))) - .serve(service_fn(|addr: &AddrStream| { - let addr = addr.remote_addr(); - - async move { Ok::<_, Infallible>(AddExtension::new(service, addr)) } - })) - .await - .unwrap(); -} diff --git a/examples/rustls_reload.rs b/examples/rustls_reload.rs deleted file mode 100644 index ba1add6..0000000 --- a/examples/rustls_reload.rs +++ /dev/null @@ -1,52 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_reload` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. -//! -//! Certificate common name will be "localhost". -//! -//! After 20 seconds, certificate common name will be "reloaded". - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - // Spawn a task to reload tls. - tokio::spawn(reload(config.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -async fn reload(config: RustlsConfig) { - // Wait for 20 seconds. - sleep(Duration::from_secs(20)).await; - - println!("reloading rustls configuration"); - - // Reload rustls configuration from new files. - config - .reload_from_pem_file( - "examples/self-signed-certs/reload/cert.pem", - "examples/self-signed-certs/reload/key.pem", - ) - .await - .unwrap(); - - println!("rustls configuration reloaded"); -} diff --git a/examples/rustls_server.rs b/examples/rustls_server.rs deleted file mode 100644 index c7cf0f2..0000000 --- a/examples/rustls_server.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_server` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{routing::get, Router}; -use hyper_server::tls_rustls::RustlsConfig; -use std::net::SocketAddr; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind_rustls(addr, config) - .serve(app.into_make_service()) - .await - .unwrap(); -} diff --git a/examples/rustls_session.rs b/examples/rustls_session.rs deleted file mode 100644 index fab499f..0000000 --- a/examples/rustls_session.rs +++ /dev/null @@ -1,80 +0,0 @@ -//! Run with `cargo run --all-features --example rustls_session` command. -//! -//! To connect through browser, navigate to "https://localhost:3000" url. - -use axum::{middleware::AddExtension, routing::get, Extension, Router}; -use futures_util::future::BoxFuture; -use hyper_server::{ - accept::Accept, - tls_rustls::{RustlsAcceptor, RustlsConfig}, -}; -use std::{io, net::SocketAddr, sync::Arc}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::server::TlsStream; -use tower::Layer; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(handler)); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - - println!("listening on {}", addr); - - let acceptor = CustomAcceptor::new(RustlsAcceptor::new(config)); - let server = hyper_server::bind(addr).acceptor(acceptor); - - server.serve(app.into_make_service()).await.unwrap(); -} - -async fn handler(tls_data: Extension) -> String { - format!("{:?}", tls_data) -} - -#[derive(Debug, Clone)] -struct TlsData { - _hostname: Option>, -} - -#[derive(Debug, Clone)] -struct CustomAcceptor { - inner: RustlsAcceptor, -} - -impl CustomAcceptor { - fn new(inner: RustlsAcceptor) -> Self { - Self { inner } - } -} - -impl Accept for CustomAcceptor -where - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, - S: Send + 'static, -{ - type Stream = TlsStream; - type Service = AddExtension; - type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let acceptor = self.inner.clone(); - - Box::pin(async move { - let (stream, service) = acceptor.accept(stream, service).await?; - let server_conn = stream.get_ref().1; - let sni_hostname = TlsData { - _hostname: server_conn.server_name().map(From::from), - }; - let service = Extension(sni_hostname).layer(service); - - Ok((stream, service)) - }) - } -} diff --git a/examples/sample.pem b/examples/sample.pem new file mode 100644 index 0000000..9e8bc64 --- /dev/null +++ b/examples/sample.pem @@ -0,0 +1,79 @@ +-----BEGIN CERTIFICATE----- +MIIEADCCAmigAwIBAgICAcgwDQYJKoZIhvcNAQELBQAwLDEqMCgGA1UEAwwhcG9u +eXRvd24gUlNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIyMDcwNDE0MzA1OFoX +DTI3MTIyNTE0MzA1OFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDL35qLQLIqswCmHJxyczYF2p0YxXCq +gMvtRcKVElnifPMFrbGCY1aYBmhIiXPGRwhfythAtYfDQsrXFADZd52JPgZCR/u6 +DQMqKD2lcvFQkf7Kee/fNTOuQTQPh1XQx4ntxvicSATwEnuU28NwVnOU//Zzq2xn +Q34gUQNHWp1pN+B1La7emm/Ucgs1/2hMxwCZYUnRoiUoRGXUSzZuWokDOstPNkjc ++AjHmxONgowogmL2jKN9BjBw/8psGoqEOjMO+Lb9iekOCzX4kqHaRUbTlbSAviQu +2Q115xiZCBCZVtNE6DUG25buvpMSEXwpLd96nLywbrSCyueC7cd01/hpAgMBAAGj +gb4wgbswDAYDVR0TAQH/BAIwADALBgNVHQ8EBAMCBsAwHQYDVR0OBBYEFHGnzC5Q +A62Wmv4zfMk/kf/BxHevMEIGA1UdIwQ7MDmAFDMRUvwxXbYDBCxOdQ9xfBnNWUz0 +oR6kHDAaMRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0GCAXswOwYDVR0RBDQwMoIO +dGVzdHNlcnZlci5jb22CFXNlY29uZC50ZXN0c2VydmVyLmNvbYIJbG9jYWxob3N0 +MA0GCSqGSIb3DQEBCwUAA4IBgQBqKNIM/JBGRmGEopm5/WNKV8UoxKPA+2jR020t +RumXMAnJEfhsivF+Zw/rDmSDpmts/3cIlesKi47f13q4Mfj1QytQUDrsuQEyRTrV +Go6BOQQ4dkS+IqnIfSuue70wpvrZHhRHNFdFt9qM5wCLQokXlP988sEWUmyPPCbO +1BEpwWcP1kx+PdY8NKOhMnfq2RfluI/m4MA4NxJqAWajAhIbDNbvP8Ov4a71HPa6 +b1q9qIQE1ut8KycTrm9K32bVbvMHvR/TPUue8W0VvV2rWTGol5TSNgEQb9w6Kyf7 +N5HlRl9kZB4K8ckWH/JVn0pYNBQPgwbcUbJ/jp6w+LHrh+UW75maOY+IGjVICud8 +6Rc5DZZ2+AAbXJQZ1HMPrw9SW/16Eh/A4CIEsvbu9J+7IoSzhgcKFzOCUojzzRSj +iU7w/HsvpltmVCAZcZ/VARFbe1By2wXX2GSw2p2FVC8orXs76QyruPAVgSHCTVes +zzBo6GLScO/3b6uAcPM3MHRGGvE= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEnzCCAoegAwIBAgIBezANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9wb255 +dG93biBSU0EgQ0EwHhcNMjIwNzA0MTQzMDU4WhcNMzIwNzAxMTQzMDU4WjAsMSow +KAYDVQQDDCFwb255dG93biBSU0EgbGV2ZWwgMiBpbnRlcm1lZGlhdGUwggGiMA0G +CSqGSIb3DQEBAQUAA4IBjwAwggGKAoIBgQCsTkd2SKiy3yy20lygOhKfOySo3qpq +TZVrpW11vQ58+6EcetXRnzIIK0HyhPmZrv9XKPpQclJvfY9jADNtu2CSj/v15OSB +Love3GzmXSZz2A8QUZBPWx6HczDG1hFGzrCZPKzpeLnFD1LPsKCUkUOHl1acyy24 +DaCacQJPzPQWbMhbGmYRlDNb+2R2K6UKMAEVe4IOTv2aSIKDGLI+xlaBXYAJj48L +//9eNmR3bMP3kkNKOKaaBk8vnYxKpZ+8ZHeHTmYWR9x1ZoMcbA9lKUwRpKAjY5JJ +NVZMDmjlVQVvvBrvhgz/zgXtfuaQCryZ0f1sEY/zXhdealo3fGVomeoniD4XwA1c +oaUFkbo5IM5HU/pXyAGRerDyhYLgRqQZMIRauvKRPN3jLsPOEQ0+gnXUUTr/YGIE +KY3/Axg4P3hzZCFqJ5IgkgWZr/dKr9p/0cxSUGHTVcpEFOlkKIIIdRuR7Ng5sJml +u7PAMWt6T+x02ORs1/WkyP7LyPQmuugYTicCAwEAAaNeMFwwHQYDVR0OBBYEFDMR +UvwxXbYDBCxOdQ9xfBnNWUz0MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEF +BQcDAjAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIB/jANBgkqhkiG9w0BAQsFAAOC +AgEAYzqmX+cNPgVD2HWgbeimUraTpI9JP5P4TbOHWmaJKecoy3Hwr71xyAOGiVXL +urk1ZZe8n++GwuDEgRajN3HO9LR1Pu9qVIzTYIsz0ORRQHxujnF7CxK/I/vrIgde +pddUdHNS0Y0g8J1emH9BgoD8a2YsGX4iDY4S4hIGBbGvQp9z8U/uG1mViAmlXybM +b8bf0dx0tEFUyu8jsQP6nFLY/HhkEcvU6SnOzZHRsFko6NE44VIsHLd2+LS2LCM/ +NfAoTzgvj41M3zQCZapaHZc9KXfdcCvEFaySKGfEZeQTUR5W0FHsF5I4NLGryf5L +h3ENQ1tgBTO5WnqL/5rbgv6va9VionPM5sbEwAcancejnkVs3NoYPIPPgBFjaFmL +hNTpT9H2owdZvEwNDChVS0b8ukNNd4cERtvy0Ohc3mk0LGN0ABzrud0fIqa51LMh +0N3dkPkiZ4XYk4yLJ5EwCrCNNH50QkGCOWpInKIPeSYcALGgBDbCDLv6rV3oSKrV +tHCZQwXVKKgU4AQu7hlHBwJ61cH44ksydOidW3MNq1kDIp7ST8s7gVrItNgFnG+L +Jpo270riwSUlWDY4hXw5Ff5lE+bWCmFyyOkLevDkD9v8M4HdwEVvafYYwn75fCIS +5OnpSeIB08kKqCtW1WBwki0rYJjWqdzI7Z1MQ/AyScAKiGM= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEsDCCApgCCQCfkxy3a+AgNjANBgkqhkiG9w0BAQsFADAaMRgwFgYDVQQDDA9w +b255dG93biBSU0EgQ0EwHhcNMjIwNzA0MTQzMDU3WhcNMzIwNzAxMTQzMDU3WjAa +MRgwFgYDVQQDDA9wb255dG93biBSU0EgQ0EwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQCj6nW8pnN50UsH2NjL97xZKxlXPe5ptXfvqXczMsw0vB3gI4xJ +Tdmrnqo0K+VOH7vh+UXcOj2ZMY2ou6oDDK5Qpu9bvGPBIJH/rC1Ti2+u5Y4KTIUc +jWAtzQJeFn8+oCMfskpLdtlWLRdAuwqNHjvxXdd2JnsX1Wid85U/rG2SNPLGjJAF +xG7xzZC4VSO2WIXTGRMUkZfFc8fhWMjo3GaeF/qYjzfHDPWN/ll/7vfxyXJO/ohw +FzpJSZtKmI+6PLxqB/oFrKfTDQUGzxjfHp187bI3eyUFMJsp18/tLYkLyxSWIg3o +bq7ZVimHd1UG2Vb5Y+5pZkh22jmJ6bAa/kmNNwbsD+5vJhW1myGhmZSxkreYPWnS +6ELrSMvbXccFfTYmdBlWsZx/zUVUzVCPe9jdJki2VXlicohqtvBQqe6LGGO37vvv +Gwu1yzQ/rJy47rnaao7fSxqM8nsDjNR2Ev1v031QpEMWjfgUW0roW3H58RZSx+kU +gzIS2CjJIqKxCp894FUQbC6r0wwAuKltl3ywz5qWkxY0O9bXS0YdEXiri5pdsWjr +84shVVQwnoVD9539CLSdHZjlOCAzvSWHZH6ta2JZjUfYYz8cLyv2c2+y9BYrlvHw +T7U7BqzngUk72gcRXd5+Onp+16gGxpGJqaxqj94Nh/yTUnr2Jd9YaXeFmQIDAQAB +MA0GCSqGSIb3DQEBCwUAA4ICAQBzIRVRt3Yaw60tpkyz/i1xbKCbtC+HqYTEsXvZ +RvZ5X1qyLAcmu4EW9RHXnlLiawDbES6lCMFfdBUK03Wis7socvoFUCBRW337F4z2 +IivHfIge4u+w5ouUKPzcpj6oeuR06tmNytYbno6l8tXJpm1eeO4KNZ0ZtodmyB5D +yLrplFgxTdGGgyvxt8LoeLwGmPCyVt35x/Mz6x2lcq1+r7QJZ9sENhQYuA8UqHrw +fmNoVIMXMEcPLcWtFl6nKTK9LrqAu1jgTBqGGZKRn5CYBBK3pNEGKiOIsZXDbyFS +F59teFpJjyeJTbUbLxXDa15J6ExkHV9wFLEvfu/nzQzg8D9yzczSdbDkE2rrrL+s +Q/H/pIXO/DesCWQ37VALn3B5gm9UBd5uogbSw8eamiwRFLQ0snP80pJQGJoTNn0P +wrLLUf2gsKC2262igiA+imepm5wxbV9XGVZfHJgxCi5Zqrf6aWnjIqD2YtDvAHhs +V8ZWN3QTjdnEcQbG0544rocoLNX/FzmyDgjfZKY5r6wt+FWNc/R4clkF+KxasxqB +HdBs8j0lGV3ujvNXASLq9HI6VxZayrSfkR73hADCXIM/wzynKwMarvA4SXwYX9Pd +cJ4+FMqrevPpamMHUsNndS0KfDTdjDp+TSBf87yiyRkD1Ri4ePslyfNvRyv3Xs7k +47YFzA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/examples/sample.rsa b/examples/sample.rsa new file mode 100644 index 0000000..d1f178a --- /dev/null +++ b/examples/sample.rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAy9+ai0CyKrMAphyccnM2BdqdGMVwqoDL7UXClRJZ4nzzBa2x +gmNWmAZoSIlzxkcIX8rYQLWHw0LK1xQA2XediT4GQkf7ug0DKig9pXLxUJH+ynnv +3zUzrkE0D4dV0MeJ7cb4nEgE8BJ7lNvDcFZzlP/2c6tsZ0N+IFEDR1qdaTfgdS2u +3ppv1HILNf9oTMcAmWFJ0aIlKERl1Es2blqJAzrLTzZI3PgIx5sTjYKMKIJi9oyj +fQYwcP/KbBqKhDozDvi2/YnpDgs1+JKh2kVG05W0gL4kLtkNdecYmQgQmVbTROg1 +BtuW7r6TEhF8KS3fepy8sG60gsrngu3HdNf4aQIDAQABAoIBAFTehqVFj2W7EqAT +9QSn9WtGcHNpbddsunfRvIj2FLj2LuzEO8r9s4Sh1jOsFKgL1e6asJ9vck7UtUAH +sbrV0pzZVx2sfZwb4p9gFRmU2eQigqCjVjnjGdqGhjeYrR62kjKLy96zFGskJpH3 +UkqnkoIKc/v+9qeeLxkg4G6JyFGOFHJAZEraxoGydJk9n/yBEZ/+3W7JUJaGOUNU +M7BYsCS2VOJr+cCqmCk1j8NvYvWWxTPsIXgGJl4EOoskzlzJnYLdh9fPFZu3uOIx +hpm3DBNp6X+qXf1lmx9EdpyeXKpLFIgJM7+nw2uWzxW7XMlRERi+5Tprc/pjrqUq +gpfyvMkCgYEA909QcJpS3qHoWyxGbI1zosVIZXdnj8L+GF/2kEQEU5iEYT+2M1U+ +gCPLr49gNwkD1FdBSCy+Fw20zi35jGmxNwhgp4V94CGYzqwQzpnvgIRBMiAIoEwI +CD5/t34DZ/82u8Gb7UYVrzOD54rJ628Q+tJEJak3TqoShbvcxJC/rXMCgYEA0wmO +SRoxrBE3rFzNQkqHbMHLe9LksW9YSIXdMBjq4DhzQEwI0YgPLajXnsLurqHaJrQA +JPtYkqiJkV7rvJLBo5wxwU+O2JKKa2jcMwuCZ4hOg5oBfK6ES9QJZUL7kDe2vsWy +rL+rnxJheUjDPBTopGHuuc9Nogid35CE0wy7S7MCgYArxB+KLeVofOKv79/uqgHC +1oL/Yegz6uAo1CLAWSki2iTjSPEnmHhdGPic8xSl6LSCyYZGDZT+Y3CR5FT7YmD4 +SkVAoEEsfwWZ3Z2D0n4uEjmvczfTlmD9hIH5qRVVPDcldxfvH64KuWUofslJHvi0 +Sq3AtHeTNknc3Ogu6SbivQKBgQC4ZAsMWHS6MTkBwvwdRd1Z62INyNDFL9JlW4FN +uxfN3cTlkwnJeiY48OOk9hFySDzBwFi3910Gl3fLqrIyy8+hUqIuk4LuO+vxuWdc +uluwdmqTlgZimGFDl/q1nXcMJYHo4fgh9D7R+E9ul2Luph43MtJRS447W2gFpNJJ +TUCA/QKBgQC07GFP2BN74UvL12f+FpZvE/UFtWnSZ8yJSq8oYpIbhmoF5EUF+XdA +E2y3l1cvmDJFo4RNZl+IQIbHACR3y1XOnh4/B9fMEsVQHK3x8exPk1vAk687bBG8 +TVDmdP52XEKHplcVoYKvGzw/wsObLAGyIbJ00t1VPU+7guTPsc+H/w== +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/examples/self-signed-certs/cert.pem b/examples/self-signed-certs/cert.pem deleted file mode 100644 index 8227f32..0000000 --- a/examples/self-signed-certs/cert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFkzCCA3ugAwIBAgIUQZiKeBISKUZoglT8J8CCPpGbgTkwDQYJKoZIhvcNAQEL -BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X -DTIxMDgyOTEyMDE0NVoXDTIyMDgyOTEyMDE0NVowWTELMAkGA1UEBhMCVVMxEzAR -BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 -IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEAoeDJnuh1lhcpKCt5VEBqO9JcSoz2wqD3SLj4i2qrEOvqb4X0ZZeN -5GQXQlOG2N6+9FOxTzaTTigTecYzI3hqKn1fiuvaS4EeTC7E1sVOj7tY0yVySjXM -pC/3t1n1s3B25m7eQ0G2JypZFCobGqY0kaRoO+mCTjI4bdCd769shIerCO4Z8FD5 -uj1+hBC7ZY/sqmRkGTLX1ZzkXzaeNeWGlkXKU8/V3qdveFQ/sGe+KoZpOPXb0yR7 -H8zf6NE2CFCNJDhytOkYLOsnvCJOvibJ3kbM2GfI9iCd0/QhQAOcrVhcOgI4aIxr -wP3zvF4PFUhFKEWHqK5IFq41xKyMYu2fw3bmKXg4zsQGcB0avBD7z+7ENEBvLkNI -7O20wKJp8u0RfjStNHWPmWLXPjkadVB5JHJjsktvgNZkbs9ugxhZWW2AzrrIuqwR -NOWnjHE7J3jvcHP6jE5O9LHpnlh6BMoKPsQuRu/bkrD34rNzwH7IX1To1CyDazMR -yhUiARYh43gg6hrrQdVjDFMHd51mgWHtOPzSLb0uzToglAa3FClGlCeaiacu4H2V -EfJrlCbVlftmIub9/EILZ6XpyYWMxt2mm4mCcMtXmBsHolP4lU3keK8AGNFOr3PC -B7NHLNp1RHgx8+Q3kzobJ1Lk+zEjraWPb5gyByUvZySbd/JTGgNCmZsCAwEAAaNT -MFEwHQYDVR0OBBYEFGsIv6GsbDS+dEWwWlA/3TG5Oi88MB8GA1UdIwQYMBaAFGsI -v6GsbDS+dEWwWlA/3TG5Oi88MA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEL -BQADggIBAHhjzP8WtkLJVfZXXUPAAekR7kaqk2hb3hIgDABBJ7xNxcktLOH7V/ng -nhbnwSH5mCkHHXx78TOhWqokHp5wru8K3de5wvAD8uz0UwNDHK5EzqtjYLzxbxAr -ht89WoXGPEZIz6MuOxVYx/HHXdgNEXUcujzfpAfvznVxvzBVqpHNgc7qO8wJd0cG -nit1XubxKoIVTEUjDfxGa2TsmBI7CZ8MLjIyztp/b3txpVl36hPC/uFLwKC780Jc -eO9saA5ISbJh7EaISRr8MKpBpJcraL+055bMjM+kzRFA18NWuuo9Y8fXnXE8e/af -k8FvclVdH/YyezaLkjW7lXjo7QoSXHhAuSzvsGmIsh+HuH+3Fs22AN3aGdmimOmp -7JiNe42mwEpJydwgGlKOysw4ht6MA6yOcQJw73QAYYwusOmNjFZtfCUqJx/JO7mn -Sb1/PW58xYSJhDxdGhoh6Rd3xPMW1T4YwpapkAC/htciK3XkwCcG1VKSmCIErkXf -vllmdahH/QkNooNAHMZl/ipYMik8pp5eRjVjCvpQTDBOI97U0+bgXydHVowP9ExE -dGcm6pP8FU1LyBZdYTdlMRC5Z0L0ltcZn7bqKcyzZB3UcWJv7Uhn3MYbmqGsUVly -a/e3kH2t5pEWRTsrNrRD94LzEYKvcNHy6PYkrgpGjh2G2VBZgNzh ------END CERTIFICATE----- diff --git a/examples/self-signed-certs/key.pem b/examples/self-signed-certs/key.pem deleted file mode 100644 index c329a2d..0000000 --- a/examples/self-signed-certs/key.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQCh4Mme6HWWFyko -K3lUQGo70lxKjPbCoPdIuPiLaqsQ6+pvhfRll43kZBdCU4bY3r70U7FPNpNOKBN5 -xjMjeGoqfV+K69pLgR5MLsTWxU6Pu1jTJXJKNcykL/e3WfWzcHbmbt5DQbYnKlkU -KhsapjSRpGg76YJOMjht0J3vr2yEh6sI7hnwUPm6PX6EELtlj+yqZGQZMtfVnORf -Np415YaWRcpTz9Xep294VD+wZ74qhmk49dvTJHsfzN/o0TYIUI0kOHK06Rgs6ye8 -Ik6+JsneRszYZ8j2IJ3T9CFAA5ytWFw6AjhojGvA/fO8Xg8VSEUoRYeorkgWrjXE -rIxi7Z/DduYpeDjOxAZwHRq8EPvP7sQ0QG8uQ0js7bTAomny7RF+NK00dY+ZYtc+ -ORp1UHkkcmOyS2+A1mRuz26DGFlZbYDOusi6rBE05aeMcTsneO9wc/qMTk70seme -WHoEygo+xC5G79uSsPfis3PAfshfVOjULINrMxHKFSIBFiHjeCDqGutB1WMMUwd3 -nWaBYe04/NItvS7NOiCUBrcUKUaUJ5qJpy7gfZUR8muUJtWV+2Yi5v38QgtnpenJ -hYzG3aabiYJwy1eYGweiU/iVTeR4rwAY0U6vc8IHs0cs2nVEeDHz5DeTOhsnUuT7 -MSOtpY9vmDIHJS9nJJt38lMaA0KZmwIDAQABAoICAHzGnCLU4+4xJBRGjlsW28wI -tgLw7TPQh0uS6GHucrW0YxxbkKrOSx0E2bjSUVrRNzd1W3LHinvwADMZR0nMA2mF -AiQ+8CDLAeOPGULDC29W5Xy7nID/PyI/px25Rd5ujffI9aG6AQHnbopQelvsSREK -PR4RO9OyejSLXXHnMipluLxFa9EFWbjotaBulUQP0Ej24QFbY2rQaGfL3d+FcFxc -pzw7M4tQXGfP6Ne836Q/vtOdDziNIiq87Mq0mIWIMYL9z80K7wuQpywo9bE0jN28 -jSExvoGZWo6J2ydQoXAsb8p286wCsPwtw7Yqek3ZSxVjotGupPp2hhN3PS70IvR5 -wcR+1pGTSzUFkrLurZftR+HNU4GHVGEzmFKtQ1dyBjDdLSkBHx+N3rzvvArMLDKI -hYXc7AgCTR1SkZBBVPFlNZJyicE+x52UGLvnyS5chgqvSsOrkhDu/bK+ISTh+3jZ -8QSnjYuZLQ1q5i3914wKzjSrHbFWuoGullqCk6nvhn2EEDcAVla0ebSYBcrnzKhO -qJogZzUSTpINIKNQlZuohzbS0lrvXuYDRDkZLRaQWKgHGiat7peBazEfd0NTHpIs -2lKovGTWNU8MIvJPONFixIZ0k7Z+s7Oje+dSOoCyCUzA3BT+mmS2Yi180zxrtRBS -LPGooWR3Rfyptx+OJkehAoIBAQDQkoPWIQWdFG1G9x08H49/AjcfGtHbdjeCjNqS -6mbXLzHgQjnUnmKmuqgkSw9IA+l2OqX4dNrKqH9P6Ex9s3HRxTmYt9/0DLT8Thus -04DiusjhUDQYV8pXUBujmVkMEEI8N5RXv0IAd59kaA6kWJLtrnp6mREY2WJicIAJ -BKut0QTC+upnvV2NKYc+Ki5ElB5hqzICr+wBq35ZlxTId7F5iaZeWeljpOodZw06 -KCVIUhmGHNVR0DUqUJ8+j7gstXhXr0MVhAlRg+WhlUvyCm1UhElyyrVgiXjqeqO9 -RO2+/poPNFxylVzYgTi54ydeB378/LcrxFQ7Q3DAW6DSAefHAoIBAQDGsBc6SnXu -WGW2qPWQM1Jm9hGy7ZgB8953kvpSxE1cVkXoOOtaa2HtRurxT55s4nTAzqDV//7R -9OX+JDCMeQLm9oLzGOxaCaq5lGNTNQs+MBPP78wwQrZRhneuG5U0lEYBb+dlkHih -IejR9OK0r0btpwuLWTC/cs2dNMW0J6JwaK6J4JiJC+nJiKyt1W98Vtpz0oLJq/Re -Z/e3sVZF3RLks5WoQsiXYoQ3KFf9koBsImggGm2prrFl9KeZJOVJP0ZeDaRcLGWQ -PRt0nNKuuSRJ5HZF/0TCwUXAtpaftAsr4fhB+/KYVdVrni5FYdfqUX4KH6n9LFSG -VC1OST1JJIeNAoIBAB0H57XMTt24VCWGi9ksg2qoQkfgEcm8QKm5NUsxuTLGbOjM -DwSbLxwJ6xFyKSRa9wnvy94zVajTnzTeHpd4fKU4EHZDUbbEdgSQUqXRoqTsXr2N -zlJ9FbrleZNh6tUVBkMfcVRtWKB8BgGRwkf51CmlGYMq/wg4actN4WRf9A1zhHgn -OK1L3FOjriFm+Z2uCDSMAaACIJVy61lJACmPD3LdR/zmAuhNshB5oYuwvs+8LbVP -GhoTIvNK2X95vabrc16xFGNQR4PDGhlNkI6WCPW0nAyQToKrX9szSsszZuwowATR -wvRn+c5g3iZxia861+AaxNwgraC6GF2N42qXvU0CggEAXD+NyUahEpSARRqVSOpL -K/q7pPOjS+TKOYJILv1tXZ3Av10OCOEqilwO4RMyXyOVSZ+mFTXSPfESh7iNweq9 -ajax/eRoeDVcyuUWaJ+MJMd1q2mOyClxNNDV6ERuNgdRqYEnUoSNPWLdEf48898d -c2HHfl9evsSyqnbCBC8SwFYaE3Hv4FFjrmqCogMiy/wXWQc4KiJoRxzGascvYyiN -iRnINmMrdv4KnQFiOR03+vzOk3kxyUKOouPAnN4Ahs2WAj0bPqBuV1XH1ZCqUO0s -6BHmyAEJD9Nka2Fa9bNGLI2yEhDERe40NM8wdI5FDUng1xp0dlOKuwOCNYLTrY4E -UQKCAQByK/e9bFaNv+BS81flfTt9tinKRIFc8IAKUl39M5wmehUqey8BGfKkMTGX -1w7R7lfCxoDi5Cl64fkPLWHrvZTuWh5ApC8r6uVjEX3TNWhBCQAB2tJmF7s9N73K -ymoh3VvQUHFZ2+IrCTgkJTWqjEdhPiiU3/oBnIv9ZYWf1ORkVhoAdxoLBn2XuTRC -xIKhiQeqCcKE9yTN26rt+7DjhB5TJ0W2meC8Rxb4lZRDD50MZayZQ6Vo4O87INpD -WjR7NdZndxUeinCPNQos9hEEke1ncCIzkwzJ9kn1R3iJzZRdjDKW3oT4G6QaStf5 -HUGWsrhnzvWoCOV+9+MdApoim8FI ------END PRIVATE KEY----- diff --git a/examples/self-signed-certs/reload/cert.pem b/examples/self-signed-certs/reload/cert.pem deleted file mode 100644 index 545c0ae..0000000 --- a/examples/self-signed-certs/reload/cert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFkTCCA3mgAwIBAgIUXN/Uw2uyZ6/Uj4LRuuK0/RdRHRswDQYJKoZIhvcNAQEL -BQAwWDELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDERMA8GA1UEAwwIcmVsb2FkZWQwHhcN -MjEwODI5MTI1NjU2WhcNMjIwODI5MTI1NjU2WjBYMQswCQYDVQQGEwJVUzETMBEG -A1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkg -THRkMREwDwYDVQQDDAhyZWxvYWRlZDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC -AgoCggIBAL/OR5KG8PqJgZSDza1lBVpZjW3jw1MA9eegePoK/4dYjd0Mdw+DeYOu -J/UmXoLHUDi/YWwZSmeY3YW0Wimwo1C5VqQL3GapSyFibvyTFE2fpoK0QtlgTKJ4 -G0mzdZ9NjibhvK23UOW5VbzlBujrYAaF2ynUha/cgVZ9uzvdwd6ooi+1i6XfHnkG -AQqGi6u/SIB+eHXn0w+tTYXmMp44jqIkjsK2vPNeifWj3MQxvgg7JTR/AKTmFCMm -BJIEP62BTFEnHJF+pRd2Hj0GIAiNBq1uA1F+HoUhxyX3OWHYCkRwPMnrSbPQOyxO -g4oFaUzAvMd2lHN/GjJS0kLwDy7WF/iXZuFxdEsmEmH62fE7N4P2uEnNw5OcHS82 -8Mc2EoMrV8zUBl4ZJ2eFo6w9lAx2bzMZyGXdOHsZWnJ5+1co6gfRfv51TeJGQx8f -JaHWFrn55qKBQmgQpKmCt/sG3HrqTviw1PtecsrzTliEXPoWdx6AhYaV+I4u8c8S -Q0NfdfjXx+5EMFDe5CvfWp/D5C1AQIV5E0Ao3Q+VfjoU/2tz9WcE5voHfyl3mBMI -FHvAPCZC18E+ZpiYyhRJLxP4z0MzTiuxp25lRi0Yt/5QTzEzFfH1UNQYe2xljPtf -syg5RtHoijcL+MncE1NUXz+B/qC4uJm8llPjFoL94Yg3/dwWOPtPAgMBAAGjUzBR -MB0GA1UdDgQWBBThri9Jq8CLFMEHJ+wE7WsCODVRwjAfBgNVHSMEGDAWgBThri9J -q8CLFMEHJ+wE7WsCODVRwjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA -A4ICAQCiR8yJ2YQyJfYDd9BT9eb9H8/S+Yz/9ayNS3zSJk4StQZaS1V6XjexzDBr -MRSr/hHGtO9G2qeocuJ/ArUJS5yYsf69g9AjuB+b41k0E4BVpiB/lENAhMbMbl+D -+ysRifUR2svHnZzKnL7DRrpS3vEUQhO37GXwbEi192rXAr2N6VE0LhxGyE7EwCzw -7gNkzoB3/Y4Fb+6zCYZorg3PmPZHrfu9vGFiP9nh+JVos9aq2JHZgZJ2N5Hcdh1H -Bci372+i1SHKfYutXrcSnUcPd4UgGQt6F63fOFHJEGsSVbHpJujqjpIscuPqgfn8 -DSkm9SEyVEV8MrY2vtwtVFOre4yjsaZ2fHDU7rCXOO88kIBBdvIpdIO4mBKV14ug -k9M1xzqK/KvgMUztuw/oLxOp7Vnii9sQ9bjzjbFEMiJ07V5Egr88Zh+VnN3ED1MH -Ri6Ho/CI/ttAwzZVhrKumOb6AprPVUteZFedpV80UaYmIthkeW0i9QcUOMkr4bL3 -gCghJeBSETTGEYCKOpcIFbvXwlc8d3KlL0Fa4EbQiw5vlPY28UChnxuZ3I0Vtetf -2F+3bLoVxfZD2Gc7p5bjGHgzUbGLFM4GgqQ6EbRh261Om9/bUxBao7mhKa23XWna -3Y4qISAqus6OolerflYJCCuWUF4N6e6fES5bqnZD49qAaIEg0A== ------END CERTIFICATE----- diff --git a/examples/self-signed-certs/reload/key.pem b/examples/self-signed-certs/reload/key.pem deleted file mode 100644 index bc7dccf..0000000 --- a/examples/self-signed-certs/reload/key.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQC/zkeShvD6iYGU -g82tZQVaWY1t48NTAPXnoHj6Cv+HWI3dDHcPg3mDrif1Jl6Cx1A4v2FsGUpnmN2F -tFopsKNQuVakC9xmqUshYm78kxRNn6aCtELZYEyieBtJs3WfTY4m4bytt1DluVW8 -5Qbo62AGhdsp1IWv3IFWfbs73cHeqKIvtYul3x55BgEKhourv0iAfnh159MPrU2F -5jKeOI6iJI7CtrzzXon1o9zEMb4IOyU0fwCk5hQjJgSSBD+tgUxRJxyRfqUXdh49 -BiAIjQatbgNRfh6FIccl9zlh2ApEcDzJ60mz0DssToOKBWlMwLzHdpRzfxoyUtJC -8A8u1hf4l2bhcXRLJhJh+tnxOzeD9rhJzcOTnB0vNvDHNhKDK1fM1AZeGSdnhaOs -PZQMdm8zGchl3Th7GVpyeftXKOoH0X7+dU3iRkMfHyWh1ha5+eaigUJoEKSpgrf7 -Btx66k74sNT7XnLK805YhFz6FncegIWGlfiOLvHPEkNDX3X418fuRDBQ3uQr31qf -w+QtQECFeRNAKN0PlX46FP9rc/VnBOb6B38pd5gTCBR7wDwmQtfBPmaYmMoUSS8T -+M9DM04rsaduZUYtGLf+UE8xMxXx9VDUGHtsZYz7X7MoOUbR6Io3C/jJ3BNTVF8/ -gf6guLiZvJZT4xaC/eGIN/3cFjj7TwIDAQABAoICAGNoV7PbeB2BEsWUIg8R4lpX -O3OOrfbg8pGfm9OLy6+r96pvAW3q6BmVM2RdBHKnNi6TEbzixqs2kOjw9iHRSHNX -+01+UDZs22FsELWazNUGP1hScKsUu+MgeJQUDIwJt/jy2cT201icW5FQ6enhw5zd -1x6w5LCmien3tAhtAEOUBqrPXpcTMknrELMR1GWo97yQz4HcKolfemRBUE6sZVAn -vk2wQ/GmN741tP+CAElnzfqNMBpGnH0zAP9kcFRORO1yZd4KUyn7r+RUvllwLdvI -vrOHt+2r+fj1TqolO/0IZpkH9uTYsTJfZtEryM1cvvppvLq3Ty5xukOzA0t07mqk -6G6217EhPSKE+DdBbsrExJjdrzBMyTQEL2qGLihhIFpDAd8WdNr8DRJrI4ZEo1Rg -Du1PuvcCscp97eTaiXSQTknUwBzHbeIkYepQYOksd+11cBXY40TR9X78LwUnfmBZ -yeAqFIBND5Z56NgPkXZ9DTeLyt6fkA9+V7WLfpxeGAdhn/JsyflIy2SQyFmRElxV -AC5/8GHgwTXjHmBJNg/PJZBHduje7BWPoCdX8X+SzE/ph/s6vzNdYsGxUFgoMshj -YlhTS9NL0Asp+KQD+bsMYxYmhvb++YIIqwdkMAP4sGD3iKFQXRRRUzldXC5A88US -1Zk0xEvYjw7F5GEKi35RAoIBAQDgH/C07vP1+qPHj3W6vOQ90T2WbS4kfpWUv5wc -KKyvZVDqBrx6R22/fn1GrdXKxrMzVIFN0AXx38NYUmUVe9tQ/nq2Lx6PFKWX5khw -84IJw0LLuXBN6NiorxV4Ep9Bf0uST81sPMmE1vDyAveUVC+FX8NAgD8Hr4tDsleF -NIijqDjVbAN6+T5qlUyuUSjSUo+KnWJ72M2PCSiUDONW93kACk77wo1Hon2YcO3H -IyAQnPJKPYNlgivm5EmEvvThJ2nmlaXwadSH9bNes8RkzcfPJybkVEFMD9nxD127 -DnuHpRBFkjGfPsb9ulLODPvfQirSSXQsR1N8hQTZACd9g6L9AoIBAQDbFabN7Ztg -CnMZ9hT8qEvau67Q8KmpaZBptuYM/W3/T4oxoPOTLZCzvVX5Xy+hOuec/N/DAP/4 -6PDTXPt6kEr31ewcQyBVQarB9bkY1t9iMa32ZsVBe00/UFrdgR3MwD3jP3pmuifT -+ZI4MyJqq4SGek7Zqjc6Unn24TSqXVsvbtILqTbRsqf5iV2LUx3NmqbX84K4EwBm -ZPrMyD0jiAd0YibewyorbhDVTKVxPtVLVCCQpLaTcvkYs1H3mSaY2yB4nBaWUto7 -3iRW497KOpsBpx4UeW4iNni9JtfPKALdIaz+X4ig7tyxwRuMVUkKd7q7faM8IGoH -45xH8w5mW4c7AoIBAQCWRhQ43LcKyOEjnxcK/Df1EuS+hboYkh9tOwRLBSKz/7S/ -FYEuY9I8QW1yBICCk7P3yMNiDwbNZIEwKR7JxuAIcHiKyxEsUmWtcaREx6D7NscE -nfOk6WjLwYkdly7c1aMwGP3dguyDezLWshKai8/JF6ptBxA78QHphByWneC4CsUA -pIm43IFzKWPexWAflWfVQy2TaIx7SWLB0dpkp02kL0VCHPJpg5O+sIldqjmHqhPy -n0gIub0B9TMuJHNAvBKPnutCRVNRTfbUmqgmBqvgQ5oaIjwd6crxjKIGF/HPw2cj -nqBS6960pUd8DMycp1ra4JFaVwCtTusvLKFN0QNpAoIBAAJ128m0QWpys5g3C0VL -Ho72TKBME5uzc8u8IhlDP1j+q66jABlHCbj7B1wllYNaBf/dVyX5fOZut0WoZaqa -tDzUSjKHDnXmpuRGvi1pPFj99dYukUiK+fMcE+ko6gzCm+9RZy6AKLJYuyumZ1yL -UJGyDfCj2Lru8i+zl8PSCJQfynwXCmaQexJyWHqYFF2avwTt1yn6DKcZuzdRiF49 -yNelwon95xtVwRqkIbeD3SFbcIIvV12QjPuaB/Gf5q8QxuyT1C0cARdrBz1yka3z -uonqNoxEUNhRhEmbhhDtghq5phe1OvOTuybD5GtPCeL0NUSlxI+ITaiJBdhJAoBj -xsECggEAKN8pJSYAScGx94fCwNbMBxMHqH3Kk83W+DF1V0ejvfhCmWZ6Vf/81xqz -a22AtpKA0EQIV/+d+4BvddMvLtKgYpYf9YR0MTyaps7DIzebr352/0WLlPZTWr5B -mzwWCtiBL0R7i6bXIiuXxqZv7zjFlXHRcj4GQI0zHT61CLGkTlF5f/25js0NkL+K -dizoG4pOA0mvZKJIKdE1GI3/20qP01BoCIHVdRHUdB0yKhoHi1EuO7hZdAPN9gsB -LMYbHG3f/dtvj0KCscKYB/Py/SmPdTW+xPZAf7tCrqZjQhPvqP2cD1UQ4glr+N2a -85DaC33fAFuevGxpS147+sAiW6doqQ== ------END PRIVATE KEY----- diff --git a/examples/shutdown.rs b/examples/shutdown.rs deleted file mode 100644 index b00943a..0000000 --- a/examples/shutdown.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! Run with `cargo run --example shutdown` command. -//! -//! To connect through browser, navigate to "http://localhost:3000" url. -//! -//! Server will shutdown in 20 seconds. - -use axum::{routing::get, Router}; -use hyper_server::Handle; -use std::{net::SocketAddr, time::Duration}; -use tokio::time::sleep; - -#[tokio::main] -async fn main() { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let handle = Handle::new(); - - // Spawn a task to shutdown server. - tokio::spawn(shutdown(handle.clone())); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - println!("listening on {}", addr); - hyper_server::bind(addr) - .handle(handle) - .serve(app.into_make_service()) - .await - .unwrap(); - - println!("server is shut down"); -} - -async fn shutdown(handle: Handle) { - // Wait 20 seconds. - sleep(Duration::from_secs(20)).await; - - println!("sending shutdown signal"); - - // Signal the server to shutdown using Handle. - handle.shutdown(); -} diff --git a/src/accept.rs b/src/accept.rs deleted file mode 100644 index 21fc431..0000000 --- a/src/accept.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! Module `accept` provides utilities for asynchronously processing and modifying IO streams and services. -//! -//! The primary trait exposed by this module is [`Accept`], which allows for asynchronous transformations -//! of input streams and services. The module also provides a default implementation, [`DefaultAcceptor`], -//! that performs no modifications and directly passes through the input stream and service. - -use std::{ - future::{Future, Ready}, - io, -}; - -/// An asynchronous trait for processing and modifying IO streams and services. -/// -/// Implementations of this trait can be used to modify or transform the input stream and service before -/// further processing. For instance, this trait could be used to perform initial authentication, logging, -/// or other setup operations on new connections. -pub trait Accept { - /// The modified or transformed IO stream produced by `accept`. - type Stream; - - /// The modified or transformed service produced by `accept`. - type Service; - - /// The Future type that is returned by `accept`. - type Future: Future>; - - /// Asynchronously process and possibly modify the given IO stream and service. - /// - /// # Parameters: - /// * `stream`: The incoming IO stream, typically a connection. - /// * `service`: The associated service with the stream. - /// - /// # Returns: - /// A future resolving to the modified stream and service, or an error. - fn accept(&self, stream: I, service: S) -> Self::Future; -} - -/// A default implementation of the [`Accept`] trait that performs no modifications. -/// -/// This is a no-op acceptor that simply passes the provided stream and service through without any transformations. -#[derive(Clone, Copy, Debug, Default)] -pub struct DefaultAcceptor; - -impl DefaultAcceptor { - /// Create a new default acceptor instance. - /// - /// # Returns: - /// An instance of [`DefaultAcceptor`]. - pub fn new() -> Self { - Self - } -} - -impl Accept for DefaultAcceptor { - type Stream = I; - type Service = S; - type Future = Ready>; - - fn accept(&self, stream: I, service: S) -> Self::Future { - std::future::ready(Ok((stream, service))) - } -} diff --git a/src/addr_incoming_config.rs b/src/addr_incoming_config.rs deleted file mode 100644 index 250dab6..0000000 --- a/src/addr_incoming_config.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::time::Duration; - -/// Configuration settings for the `AddrIncoming`. -/// -/// This configuration structure is designed to be used in conjunction with the -/// [`AddrIncoming`](hyper::server::conn::AddrIncoming) type from the Hyper crate. -/// It provides a mechanism to customize server settings like TCP keepalive probes, -/// error handling, and other TCP socket-level configurations. -#[derive(Debug, Clone)] -pub struct AddrIncomingConfig { - pub(crate) tcp_sleep_on_accept_errors: bool, - pub(crate) tcp_keepalive: Option, - pub(crate) tcp_keepalive_interval: Option, - pub(crate) tcp_keepalive_retries: Option, - pub(crate) tcp_nodelay: bool, -} - -impl Default for AddrIncomingConfig { - fn default() -> Self { - Self::new() - } -} - -impl AddrIncomingConfig { - /// Creates a new `AddrIncomingConfig` with default settings. - /// - /// # Default Settings - /// - Sleep on accept errors: `true` - /// - TCP keepalive probes: Disabled (`None`) - /// - Duration between keepalive retransmissions: None - /// - Number of keepalive retransmissions: None - /// - `TCP_NODELAY` option: `false` - /// - /// # Returns - /// - /// A new `AddrIncomingConfig` instance with default settings. - pub fn new() -> AddrIncomingConfig { - Self { - tcp_sleep_on_accept_errors: true, - tcp_keepalive: None, - tcp_keepalive_interval: None, - tcp_keepalive_retries: None, - tcp_nodelay: false, - } - } - - /// Creates a cloned copy of the current configuration. - /// - /// This method can be useful when you want to preserve the original settings and - /// create a modified configuration based on the current one. - /// - /// # Returns - /// - /// A cloned `AddrIncomingConfig`. - pub fn build(&mut self) -> Self { - self.clone() - } - - /// Specifies whether to pause (sleep) when an error occurs while accepting a connection. - /// - /// This can be useful to prevent rapidly exhausting file descriptors in scenarios - /// where errors might be transient or frequent. - /// - /// # Parameters - /// - /// - `val`: Whether to sleep on accept errors. Default is `true`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_sleep_on_accept_errors(&mut self, val: bool) -> &mut Self { - self.tcp_sleep_on_accept_errors = val; - self - } - - /// Configures the frequency of TCP keepalive probes. - /// - /// TCP keepalive probes are used to detect whether a peer is still connected. - /// - /// # Parameters - /// - /// - `val`: Duration between keepalive probes. Setting to `None` disables keepalive probes. Default is `None`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive = val; - self - } - - /// Configures the duration between two successive TCP keepalive retransmissions. - /// - /// If an acknowledgment to a previous keepalive probe isn't received within this duration, - /// a new probe will be sent. - /// - /// # Parameters - /// - /// - `val`: Duration between keepalive retransmissions. Default is no interval (`None`). - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive_interval(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive_interval = val; - self - } - - /// Configures the number of times to retransmit a TCP keepalive probe if no acknowledgment is received. - /// - /// After the specified number of retransmissions, the remote end is considered unavailable. - /// - /// # Parameters - /// - /// - `val`: Number of retransmissions before considering the remote end unavailable. Default is no retry (`None`). - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_keepalive_retries(&mut self, val: Option) -> &mut Self { - self.tcp_keepalive_retries = val; - self - } - - /// Configures the `TCP_NODELAY` option for accepted connections. - /// - /// When enabled, this option disables Nagle's algorithm, which can reduce latencies for small packets. - /// - /// # Parameters - /// - /// - `val`: Whether to enable `TCP_NODELAY`. Default is `false`. - /// - /// # Returns - /// - /// A mutable reference to the current `AddrIncomingConfig`. - pub fn tcp_nodelay(&mut self, val: bool) -> &mut Self { - self.tcp_nodelay = val; - self - } -} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..78d89a5 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,89 @@ +use std::{error::Error as StdError, fmt}; + +/// A type alias for the source of an error, which is a boxed trait object. +/// This allows for dynamic dispatch and type erasure of the original error type. +type Source = Box; + +/// Represents errors that originate from the server. +/// This struct provides a public API for error handling. +pub struct Error { + inner: ErrorImpl, +} + +/// The internal implementation of the Error struct. +/// This separation allows for better control over the public API. +struct ErrorImpl { + kind: Kind, + source: Option, +} + +/// Enum representing different kinds of errors that can occur. +/// Currently, only includes a Transport variant, but can be extended for more error types. +#[derive(Debug)] +pub enum Kind { + Transport, +} + +impl Error { + /// Creates a new Error with a specific kind. + pub fn new(kind: Kind) -> Self { + Self { + inner: ErrorImpl { kind, source: None }, + } + } + + /// Attaches a source error to this Error. + /// This method consumes self and returns a new Error, allowing for method chaining. + pub fn with(mut self, source: impl Into) -> Self { + self.inner.source = Some(source.into()); + self + } + + /// Creates a new Transport Error with the given source. + /// This is a convenience method combining new() and with(). + pub fn from_source(source: impl Into) -> Self { + Error::new(Kind::Transport).with(source) + } + + /// Returns a string slice describing the error. + fn description(&self) -> &str { + match &self.inner.kind { + Kind::Transport => "transport error", + } + } +} + +/// Implements the Debug trait for Error. +/// This provides a custom debug representation of the error. +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_tuple("tonic::transport::Error"); + + f.field(&self.inner.kind); + + if let Some(source) = &self.inner.source { + f.field(source); + } + + f.finish() + } +} + +/// Implements the Display trait for Error. +/// This provides a user-friendly string representation of the error. +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.description()) + } +} + +/// Implements the std::error::Error trait for Error. +/// This allows our custom Error to be used with the standard error handling mechanisms in Rust. +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.inner + .source + .as_ref() + .map(|source| &**source as &(dyn StdError + 'static)) + } +} diff --git a/src/fuse.rs b/src/fuse.rs new file mode 100644 index 0000000..944c1c0 --- /dev/null +++ b/src/fuse.rs @@ -0,0 +1,42 @@ +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// `Fuse` is a wrapper around a future that ensures it can only complete once. +/// After the wrapped future completes, all subsequent polls will return `Poll::Pending`. +/// +/// This struct is borrowed from the `futures-util` crate and is used in the hyper server. +/// LICENSE: MIT or Apache-2.0 +#[pin_project] +pub(crate) struct Fuse { + /// The wrapped future. Once it completes, this will be set to `None`. + #[pin] + pub(crate) inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + /// The output type is the same as the wrapped future's output type. + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Match on the pinned inner future + match self.as_mut().project().inner.as_pin_mut() { + // If we have a future, poll it + Some(fut) => fut.poll(cx).map(|output| { + // If the future completes, set inner to None + // This ensures that future calls to poll will return Poll::Pending + self.project().inner.set(None); + output + }), + // If inner is None, it means the future has already completed, + // So we return Poll::Pending + // and yes, this is confusing and counterintuitive naming, + // but apparently this is how it's done. + None => Poll::Pending, + } + } +} diff --git a/src/handle.rs b/src/handle.rs deleted file mode 100644 index bd5269c..0000000 --- a/src/handle.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::notify_once::NotifyOnce; -use std::{ - net::SocketAddr, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; -use tokio::{sync::Notify, time::sleep}; - -/// A handle to manage and interact with the server. -/// -/// `Handle` provides methods to access server information, such as the number of active connections, -/// and to perform actions like initiating a shutdown. -#[derive(Clone, Debug, Default)] -pub struct Handle { - inner: Arc, -} - -#[derive(Debug, Default)] -struct HandleInner { - addr: Mutex>, - addr_notify: Notify, - conn_count: AtomicUsize, - shutdown: NotifyOnce, - graceful: NotifyOnce, - graceful_dur: Mutex>, - conn_end: NotifyOnce, -} - -impl Handle { - /// Create a new handle for the server. - /// - /// # Returns - /// - /// A new `Handle` instance. - pub fn new() -> Self { - Self::default() - } - - /// Get the number of active connections to the server. - /// - /// # Returns - /// - /// The number of active connections. - pub fn connection_count(&self) -> usize { - self.inner.conn_count.load(Ordering::SeqCst) - } - - /// Initiate an immediate shutdown of the server. - /// - /// This method will terminate the server without waiting for active connections to close. - pub fn shutdown(&self) { - self.inner.shutdown.notify_waiters(); - } - - /// Initiate a graceful shutdown of the server. - /// - /// The server will wait for active connections to close before shutting down. If a duration - /// is provided, the server will wait up to that duration for active connections to close - /// before forcing a shutdown. - /// - /// # Parameters - /// - /// - `duration`: Maximum time to wait for active connections to close. `None` means the server - /// will wait indefinitely. - pub fn graceful_shutdown(&self, duration: Option) { - *self.inner.graceful_dur.lock().unwrap() = duration; - self.inner.graceful.notify_waiters(); - } - - /// Wait until the server starts listening and then returns its local address and port. - /// - /// # Returns - /// - /// The local `SocketAddr` if the server successfully binds, otherwise `None`. - pub async fn listening(&self) -> Option { - let notified = self.inner.addr_notify.notified(); - - if let Some(addr) = *self.inner.addr.lock().unwrap() { - return Some(addr); - } - - notified.await; - - *self.inner.addr.lock().unwrap() - } - - /// Internal method to notify the handle when the server starts listening on a particular address. - pub(crate) fn notify_listening(&self, addr: Option) { - *self.inner.addr.lock().unwrap() = addr; - self.inner.addr_notify.notify_waiters(); - } - - /// Creates a watcher that monitors server status and connection activity. - pub(crate) fn watcher(&self) -> Watcher { - Watcher::new(self.clone()) - } - - /// Internal method to wait until the server is shut down. - pub(crate) async fn wait_shutdown(&self) { - self.inner.shutdown.notified().await; - } - - /// Internal method to wait until the server is gracefully shut down. - pub(crate) async fn wait_graceful_shutdown(&self) { - self.inner.graceful.notified().await; - } - - /// Internal method to wait until all connections have ended, or the optional graceful duration has expired. - pub(crate) async fn wait_connections_end(&self) { - if self.inner.conn_count.load(Ordering::SeqCst) == 0 { - return; - } - - let deadline = *self.inner.graceful_dur.lock().unwrap(); - - match deadline { - Some(duration) => tokio::select! { - biased; - _ = sleep(duration) => self.shutdown(), - _ = self.inner.conn_end.notified() => (), - }, - None => self.inner.conn_end.notified().await, - } - } -} - -/// A watcher that monitors server status and connection activity. -/// -/// The watcher keeps track of active connections and listens for shutdown or graceful shutdown signals. -pub(crate) struct Watcher { - handle: Handle, -} - -impl Watcher { - /// Creates a new watcher linked to the given server handle. - fn new(handle: Handle) -> Self { - handle.inner.conn_count.fetch_add(1, Ordering::SeqCst); - Self { handle } - } - - /// Internal method to wait until the server is gracefully shut down. - pub(crate) async fn wait_graceful_shutdown(&self) { - self.handle.wait_graceful_shutdown().await - } - - /// Internal method to wait until the server is shut down. - pub(crate) async fn wait_shutdown(&self) { - self.handle.wait_shutdown().await - } -} - -impl Drop for Watcher { - /// Reduces the active connection count when a watcher is dropped. - /// - /// If the connection count reaches zero and a graceful shutdown has been initiated, the server is notified that - /// all connections have ended. - fn drop(&mut self) { - let count = self.handle.inner.conn_count.fetch_sub(1, Ordering::SeqCst) - 1; - - if count == 0 && self.handle.inner.graceful.is_notified() { - self.handle.inner.conn_end.notify_waiters(); - } - } -} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..b497917 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,975 @@ +use crate::io::Transport; +use std::future::pending; +use std::{future::Future, pin::pin, sync::Arc}; +use tokio_rustls::TlsAcceptor; + +use bytes::Bytes; +use http::{Request, Response}; +use http_body::Body; +use hyper::body::Incoming; +use hyper::service::Service; +use hyper_util::rt::TokioTimer; +use hyper_util::{ + rt::TokioIo, + server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec}, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::sleep; +use tokio::time::Duration; +use tokio_stream::Stream; +use tokio_stream::StreamExt as _; +use tracing::{debug, trace}; + +use crate::fuse::Fuse; + +/// Sleeps for a specified duration or waits indefinitely. +/// +/// This function is used to implement timeouts or indefinite waiting periods. +/// +/// # Arguments +/// +/// * `wait_for` - An `Option` specifying how long to sleep. +/// If `None`, the function will wait indefinitely. +#[inline] +async fn sleep_or_pending(wait_for: Option) { + match wait_for { + Some(wait) => sleep(wait).await, + None => pending().await, + }; +} + +/// Serves HTTP an HTTP connection on the transport from a hyper service backend. +/// +/// This method handles an HTTP connection on a given transport `IO`, processing requests through +/// the provided service and managing the connection lifecycle. +/// +/// # Type Parameters +/// +/// * `B`: The body type for the HTTP response. +/// * `IO`: The I/O type for the HTTP connection. +/// * `S`: The service type that processes HTTP requests. +/// * `E`: The executor type for the HTTP server connection. +/// +/// # Arguments +/// +/// * `hyper_io`: The I/O object representing the inbound hyper IO stream. +/// * `hyper_service`: The hyper `Service` implementation used to process HTTP requests. +/// * `builder`: A `Builder` used to create and serve the HTTP connection. +/// * `watcher`: An optional `tokio::sync::watch::Receiver` for graceful shutdown signaling. +/// * `max_connection_age`: An optional `Duration` specifying the maximum age of the connection +/// before initiating a graceful shutdown. +#[inline] +pub async fn serve_http_connection( + hyper_io: IO, + hyper_service: S, + builder: HttpConnectionBuilder, + watcher: Option>, + max_connection_age: Option, +) where + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + Send + Sync, + IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + E: HttpServerConnExec + Send + Sync + 'static, +{ + // Set up a fused future for the watcher + let mut watcher = watcher.clone(); + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + // Set up the sleep future for max connection age + let sleep = sleep_or_pending(max_connection_age); + tokio::pin!(sleep); + + // TODO(It's absolutely terrible that we have to clone the builder here) + // and configure it rather than passing it in. + // this is due to an API flaw in the hyper_util crate. + // this builder doesn't have a way to convert back to a builder + // once you start building. + + // Configure the builder + let mut builder = builder.clone(); + builder + // HTTP/1 settings + .http1() + // Enable half-close for better connection handling + .half_close(true) + // Enable keep-alive to reduce overhead for multiple requests + .keep_alive(true) + // Increase max buffer size to 256KB for better performance with larger payloads + .max_buf_size(256 * 1024) + // Enable immediate flushing of pipelined responses + .pipeline_flush(true) + // Preserve original header case for compatibility + .preserve_header_case(true) + // Disable automatic title casing of headers to reduce processing overhead + .title_case_headers(false) + // HTTP/2 settings + .http2() + // Add the timer to the builder + // This will cause you all sorts of pain otherwise + // https://github.com/seanmonstar/reqwest/issues/2421 + // https://github.com/rustls/hyper-rustls/issues/287 + .timer(TokioTimer::new()) + // Increase initial stream window size to 2MB for better throughput + .initial_stream_window_size(Some(2 * 1024 * 1024)) + // Increase initial connection window size to 4MB for improved performance + .initial_connection_window_size(Some(4 * 1024 * 1024)) + // Enable adaptive window for dynamic flow control + .adaptive_window(true) + // Increase max frame size to 32KB for larger data chunks + .max_frame_size(Some(32 * 1024)) + // Allow up to 2000 concurrent streams for better parallelism + .max_concurrent_streams(Some(2000)) + // Increase max send buffer size to 2MB for improved write performance + .max_send_buf_size(2 * 1024 * 1024) + // Enable CONNECT protocol support for proxying and tunneling + .enable_connect_protocol() + // Increase max header list size to 32KB to handle larger headers + .max_header_list_size(32 * 1024) + // Set keep-alive interval to 10 seconds for more responsive connection management + .keep_alive_interval(Some(Duration::from_secs(10))) + // Set keep-alive timeout to 30 seconds to balance connection reuse and resource conservation + .keep_alive_timeout(Duration::from_secs(30)); + + // Create and pin the HTTP connection + // + // This handles all the HTTP connection logic via hyper. + // This is a pointer to a blocking task, effectively + // Which tells us how it's doing via the hyper_io transport. + let mut conn = pin!(builder.serve_connection_with_upgrades(hyper_io, hyper_service)); + + // Here we wait for the http connection to terminate + loop { + tokio::select! { + // Handle the connection result + rv = &mut conn => { + if let Err(err) = rv { + // Log any errors that occur while serving the HTTP connection + debug!("failed serving HTTP connection: {:#}", err); + } + break; + }, + // Handle max connection age timeout + _ = &mut sleep => { + // Initiate a graceful shutdown when max connection age is reached + conn.as_mut().graceful_shutdown(); + sleep.set(sleep_or_pending(None)); + }, + // Handle graceful shutdown signal + _ = &mut sig => { + // Initiate a graceful shutdown when signal is received + conn.as_mut().graceful_shutdown(); + } + } + } + + trace!("HTTP connection closed"); +} + +/// Serves HTTP/HTTPS requests with graceful shutdown capability. +/// +/// This function sets up an HTTP/HTTPS server that can handle incoming connections and +/// process requests using the provided service. It supports both plain HTTP and HTTPS +/// connections, as well as graceful shutdown. +/// +/// # Type Parameters +/// +/// * `E`: The executor type for the HTTP server connection. +/// * `F`: The future type for the shutdown signal. +/// * `I`: The incoming stream of IO objects. +/// * `IO`: The I/O type for the HTTP connection. +/// * `IE`: The error type for the incoming stream. +/// * `ResBody`: The response body type. +/// * `S`: The service type that processes HTTP requests. +/// +/// # Arguments +/// +/// * `service`: The service used to process HTTP requests. +/// * `incoming`: The stream of incoming connections. +/// * `builder`: The `HttpConnectionBuilder` used to configure the server. +/// * `tls_config`: An optional TLS configuration for HTTPS support. +/// * `signal`: An optional future that, when resolved, signals the server to shut down gracefully. +/// +/// # Returns +/// +/// A `Result` indicating success or failure of the server operation. +/// +/// # Examples +/// +/// These examples provide some very basic ways to use the server. With that said, +/// the server is very flexible and can be used in a variety of ways. This is +/// because you as the integrator have control over every level of the stack at +/// construction, with all the native builders exposed via generics. +/// +/// Setting up an HTTP server with graceful shutdown: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::ServiceBuilder; +/// use std::net::SocketAddr; +/// +/// use hyper_server::serve_http_with_shutdown; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// let svc = hyper::service::service_fn(hello); +/// let svc = ServiceBuilder::new().service(svc); +/// +/// tokio::spawn(async move { +/// // Simulate a shutdown signal after 60 seconds +/// tokio::time::sleep(std::time::Duration::from_secs(60)).await; +/// let _ = shutdown_tx.send(()); +/// }); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// None, // No TLS config for plain HTTP +/// Some(async { +/// shutdown_rx.await.ok(); +/// }), +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting up an HTTPS server: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use std::sync::Arc; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::ServiceBuilder; +/// use rustls::ServerConfig; +/// use std::io; +/// use std::net::SocketAddr; +/// use std::future::Future; +/// +/// use hyper_server::{serve_http_with_shutdown, load_certs, load_private_key}; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// let svc = hyper::service::service_fn(hello); +/// let svc = ServiceBuilder::new().service(svc); +/// +/// // Set up TLS config +/// let certs = load_certs("examples/sample.pem")?; +/// let key = load_private_key("examples/sample.rsa")?; +/// +/// let config = ServerConfig::builder() +/// .with_no_client_auth() +/// .with_single_cert(certs, key) +/// .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; +/// let tls_config = Arc::new(config); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// Some(tls_config), +/// Some(std::future::pending::<()>()), // A never-resolving future as a placeholder +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// Setting up an HTTPS server with a Tower service: +/// +/// ```rust,no_run +/// use std::convert::Infallible; +/// use std::sync::Arc; +/// use bytes::Bytes; +/// use http_body_util::Full; +/// use hyper::body::Incoming; +/// use hyper::{Request, Response}; +/// use hyper_util::rt::TokioExecutor; +/// use hyper_util::server::conn::auto::Builder as HttpConnectionBuilder; +/// use hyper_util::service::TowerToHyperService; +/// use tokio::net::TcpListener; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tower::{ServiceBuilder, ServiceExt}; +/// use rustls::ServerConfig; +/// use std::io; +/// use std::net::SocketAddr; +/// use std::future::Future; +/// +/// use hyper_server::{serve_http_with_shutdown, load_certs, load_private_key}; +/// +/// async fn hello(_: Request) -> Result>, Infallible> { +/// Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) +/// } +/// +/// #[tokio::main(flavor = "current_thread")] +/// async fn main() -> Result<(), Box> { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); +/// let listener = TcpListener::bind(addr).await?; +/// let incoming = TcpListenerStream::new(listener); +/// +/// let builder = HttpConnectionBuilder::new(TokioExecutor::new()); +/// +/// // Set up the Tower service +/// let svc = tower::service_fn(hello); +/// let svc = ServiceBuilder::new() +/// .service(svc); +/// +/// // Convert the Tower service to a Hyper service +/// let svc = TowerToHyperService::new(svc); +/// +/// // Set up TLS config +/// let certs = load_certs("examples/sample.pem")?; +/// let key = load_private_key("examples/sample.rsa")?; +/// +/// let config = ServerConfig::builder() +/// .with_no_client_auth() +/// .with_single_cert(certs, key) +/// .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; +/// let tls_config = Arc::new(config); +/// +/// serve_http_with_shutdown( +/// svc, +/// incoming, +/// builder, +/// Some(tls_config), +/// Some(std::future::pending::<()>()), // A never-resolving future as a placeholder +/// ).await?; +/// +/// Ok(()) +/// } +/// ``` +/// +/// # Notes +/// +/// - The server will continue to accept new connections until the `signal` future resolves. +/// - When using TLS, make sure to provide a properly configured `ServerConfig`. +/// - The function will return when all connections have been closed after the shutdown signal. +#[inline] +pub async fn serve_http_with_shutdown( + service: S, + incoming: I, + builder: HttpConnectionBuilder, + tls_config: Option>, + signal: Option, +) -> Result<(), super::Error> +where + F: Future + Send + 'static, + I: Stream> + Send + 'static, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, + ResBody: Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync, + E: HttpServerConnExec + Send + Sync + 'static, +{ + // Create a channel for signaling graceful shutdown to listening connections + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + // We say that graceful shutdown is enabled if a signal is provided + let graceful = signal.is_some(); + + // The signal future that will resolve when the server should shut down + let mut sig = pin!(Fuse { inner: signal }); + + // Prepare the incoming stream of TCP connections + // from the provided stream of IO objects, which is coming + // most likely from a TCP stream. + let incoming = crate::tcp::serve_tcp_incoming(incoming); + + // Pin the incoming stream to the stack + let mut incoming = pin!(incoming); + + // Create TLS acceptor if TLS config is provided + let tls_acceptor = tls_config.map(TlsAcceptor::from); + + // Enter the main server loop + loop { + // Select between the future which returns first, + // A shutdown signal or an incoming IO result. + tokio::select! { + // Check if we received a graceful shutdown signal for the server + _ = &mut sig => { + // Exit the loop if we did, and shut down the server + trace!("signal received, shutting down"); + break; + }, + // Wait for the next IO result from the incoming stream + io = incoming.next() => { + // If we got an IO result from the incoming stream + // This effectively demultiplexes the incoming stream of IO objects, + // which each represent a connection which may then be individually + // streamed/handled. + // + // So this is effectively a demultiplexer for the incoming stream of IO objects. + // + // Because of the way the stream handling is implemented, + // the responses are multiplexed back over the same stream to the client. + // However, that would not be intuitive just from looking it this code + // because the reverse multiplexing is "invisible" to the reader. + let io = match io { + // We check if it's a valid stream + Some(Ok(io)) => io, + // or if it's a non-fatal error + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + // if it's a non-fatal error, we continue processing IO objects + continue; + }, + None => { + // If we got a fatal error, meaning we lost connection or something else + // we break out of the loop + break + }, + }; + + trace!("TCP streaming connection accepted"); + + // For each of these TCP streams, we are going to want to + // spawn a new task to handle the connection. + + // Clone necessary values for the spawned task + let service = service.clone(); + let builder = builder.clone(); + let tls_acceptor = tls_acceptor.clone(); + let signal_rx = signal_rx.clone(); + + // Spawn a new task to handle this connection + tokio::spawn(async move { + // Abstract the transport layer for hyper + + let transport = if let Some(tls_acceptor) = &tls_acceptor { + // If TLS is enabled, then we perform a TLS handshake + // Clone the TLS acceptor and IO for use in the blocking task + let tls_acceptor = tls_acceptor.clone(); + let io = io; + + match tokio::task::spawn_blocking(move || { + // Perform the TLS handshake in a blocking task + // Because this is one of the most computationally heavy things the sever does. + // In the case of ECDSA and very fast handshakes, this has more downside + // than upside, but in the case of RSA and slow handshakes, this is a good idea. + // It amortizes out to about 2 µs of overhead per connection. + // and moves this computationally heavy task off the main thread pool. + tokio::runtime::Handle::current().block_on(tls_acceptor.accept(io)) + }).await { + // Handle the result of the TLS handshake + Ok(Ok(tls_stream)) => Transport::new_tls(tls_stream), + Ok(Err(e)) => { + // This connection failed to handshake + debug!("TLS handshake failed: {:#}", e); + return; + }, + Err(e) => { + // This connection was malformed and the server was unable to handle it + debug!("TLS handshake task panicked: {:#}", e); + return; + } + + } + } + else { + // If TLS is not enabled, then we use a plain transport + Transport::new_plain(io) + }; + + // Convert our abstracted tokio transport into a hyper transport + let hyper_io = TokioIo::new(transport); + + // Serve the HTTP connections on this transport + serve_http_connection( + hyper_io, + service, + builder, + graceful.then_some(signal_rx), + None + ).await; + }); + } + } + } + + // Handle graceful shutdown + if graceful { + // Broadcast the shutdown signal to all connections + let _ = signal_tx.send(()); + // Drop the sender to signal that no more connections will be accepted + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + // TODO(Add a timeout here, optionally) + signal_tx.closed().await; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{load_certs, load_private_key}; + use bytes::Bytes; + use http_body_util::{BodyExt, Empty, Full}; + use hyper::{body::Incoming, Request, Response, StatusCode}; + use hyper_util::rt::TokioExecutor; + use hyper_util::service::TowerToHyperService; + use rustls::ServerConfig; + use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::oneshot; + use tokio_stream::wrappers::TcpListenerStream; + + // Utility functions + + async fn echo(req: Request) -> Result>, hyper::Error> { + match (req.method(), req.uri().path()) { + (&hyper::Method::GET, "/") => { + Ok(Response::new(Full::new(Bytes::from("Hello, World!")))) + } + (&hyper::Method::POST, "/echo") => { + let body = req.collect().await?.to_bytes(); + Ok(Response::new(Full::new(body))) + } + _ => { + let mut res = Response::new(Full::new(Bytes::from("Not Found"))); + *res.status_mut() = StatusCode::NOT_FOUND; + Ok(res) + } + } + } + + async fn setup_test_server(addr: SocketAddr) -> (TcpListenerStream, SocketAddr) { + let listener = TcpListener::bind(addr).await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + let incoming = TcpListenerStream::new(listener); + (incoming, server_addr) + } + + async fn create_test_tls_config() -> Arc { + let certs = load_certs("examples/sample.pem").unwrap(); + let key = load_private_key("examples/sample.rsa").unwrap(); + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .unwrap(); + Arc::new(config) + } + + async fn send_request( + addr: SocketAddr, + req: Request>, + ) -> Result, Box> { + let stream = TcpStream::connect(addr).await?; + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + Ok(sender.send_request(req).await?) + } + + // HTTP Tests + + mod http_tests { + use super::*; + + #[tokio::test] + async fn test_http_basic_requests() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + // Test GET request + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + + // Test POST request + let req = Request::builder() + .method(hyper::Method::POST) + .uri("/echo") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // Test 404 response + let req = Request::builder() + .uri("/not_found") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_http_concurrent_requests() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let mut handles = vec![]; + for _ in 0..10 { + let addr = server_addr; + let handle = tokio::spawn(async move { + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(addr, req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_http_graceful_shutdown() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + None, + Some(async { + shutdown_rx.await.ok(); + }), + )); + + // Send a request before shutdown + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + let res = send_request(server_addr, req) + .await + .expect("Failed to send initial request"); + assert_eq!(res.status(), StatusCode::OK); + + // Initiate graceful shutdown + shutdown_tx.send(()).unwrap(); + + // Wait for the server to shut down + let shutdown_timeout = Duration::from_millis(150); + let shutdown_result = tokio::time::timeout(shutdown_timeout, async { + loop { + tokio::time::sleep(Duration::from_millis(10)).await; + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + match send_request(server_addr, req).await { + Ok(_) => continue, // Server still accepting connections + Err(e) if e.to_string().contains("Connection refused") => { + // Server has shut down as expected + return Ok(()); + } + Err(e) => return Err(e), // Unexpected error + } + } + }) + .await; + + match shutdown_result { + Ok(Ok(())) => println!("Server shut down successfully"), + Ok(Err(e)) => panic!("Unexpected error during shutdown: {}", e), + Err(_) => panic!("Timeout waiting for server to shut down"), + } + + // Ensure the server task completes + server.await.unwrap().unwrap(); + } + } + + // HTTPS Tests + + mod https_tests { + use super::*; + + async fn create_https_client() -> ( + tokio_rustls::TlsConnector, + rustls::pki_types::ServerName<'static>, + ) { + let mut root_cert_store = rustls::RootCertStore::empty(); + root_cert_store.add_parsable_certificates(load_certs("examples/sample.pem").unwrap()); + + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + + let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + let domain = rustls::pki_types::ServerName::try_from("localhost") + .expect("Failed to create ServerName"); + + (tls_connector, domain) + } + + #[tokio::test] + async fn test_https_connection() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let (tls_connector, domain) = create_https_client().await; + + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .unwrap(); + + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let body = res.collect().await.unwrap().to_bytes(); + assert_eq!(&body[..], b"Hello, World!"); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_https_invalid_client_cert() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + + let tls_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let domain = rustls::pki_types::ServerName::try_from("localhost").unwrap(); + + let result = tls_connector.connect(domain, tcp_stream).await; + assert!( + result.is_err(), + "Expected TLS connection to fail due to invalid client certificate" + ); + + shutdown_tx.send(()).unwrap(); + server.await.unwrap().unwrap(); + } + #[tokio::test] + async fn test_https_graceful_shutdown() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (incoming, server_addr) = setup_test_server(addr).await; + + let tls_config = create_test_tls_config().await; + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let http_server_builder = HttpConnectionBuilder::new(TokioExecutor::new()); + let tower_service_fn = tower::service_fn(echo); + let hyper_service = TowerToHyperService::new(tower_service_fn); + + let server = tokio::spawn(serve_http_with_shutdown( + hyper_service, + incoming, + http_server_builder, + Some(tls_config), + Some(async { + shutdown_rx.await.ok(); + }), + )); + + let (tls_connector, domain) = create_https_client().await; + + // Establish a connection + let tcp_stream = TcpStream::connect(server_addr).await.unwrap(); + let tls_stream = tls_connector.connect(domain, tcp_stream).await.unwrap(); + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .unwrap(); + + tokio::spawn(async move { + if let Err(err) = conn.await { + eprintln!("Connection failed: {:?}", err); + } + }); + + // Send a request + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let res = sender.send_request(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + // Initiate graceful shutdown + shutdown_tx.send(()).unwrap(); + + // Wait a bit to allow the server to start shutting down + tokio::time::sleep(Duration::from_millis(10)).await; + + // Try to send another request, it should fail + let req = Request::builder() + .uri("/") + .body(Empty::::new()) + .unwrap(); + + let result = sender.send_request(req).await; + assert!( + result.is_err(), + "Expected request to fail after graceful shutdown" + ); + + server.await.unwrap().unwrap(); + } + } +} diff --git a/src/http_config.rs b/src/http_config.rs deleted file mode 100644 index 481a81e..0000000 --- a/src/http_config.rs +++ /dev/null @@ -1,263 +0,0 @@ -use hyper::server::conn::Http; -use std::time::Duration; - -/// Represents a configuration for the [`Http`] protocol. -/// This allows for detailed customization of various HTTP/1 and HTTP/2 settings. -#[derive(Debug, Clone)] -pub struct HttpConfig { - /// The inner HTTP configuration from the `hyper` crate. - pub(crate) inner: Http, -} - -impl Default for HttpConfig { - /// Provides a default HTTP configuration. - fn default() -> Self { - Self::new() - } -} - -impl HttpConfig { - /// Creates a new `HttpConfig` with default settings. - pub fn new() -> HttpConfig { - Self { inner: Http::new() } - } - - /// Clones the current configuration state and returns it. - /// Useful for building configurations dynamically. - pub fn build(&mut self) -> Self { - self.clone() - } - - /// Configures whether to exclusively support HTTP/1. - /// - /// When enabled, only HTTP/1 requests are processed, and HTTP/2 requests are rejected. - /// - /// Default is `false`. - pub fn http1_only(&mut self, val: bool) -> &mut Self { - self.inner.http1_only(val); - self - } - - /// Specifies if HTTP/1 connections should be allowed to use half-closures. - /// - /// A half-closure in TCP occurs when one side of the data stream is terminated, - /// but the other side remains open. This setting, when `true`, ensures the server - /// doesn't immediately close a connection if a client shuts down their sending side - /// while waiting for a response. - /// - /// Default is `false`. - pub fn http1_half_close(&mut self, val: bool) -> &mut Self { - self.inner.http1_half_close(val); - self - } - - /// Enables or disables the keep-alive feature for HTTP/1 connections. - /// - /// Keep-alive allows the connection to be reused for multiple requests and responses. - /// - /// Default is true. - pub fn http1_keep_alive(&mut self, val: bool) -> &mut Self { - self.inner.http1_keep_alive(val); - self - } - - /// Determines if HTTP/1 connections should write headers with title-case naming. - /// - /// For example, turning this setting `true` would send headers as "Content-Type" instead of "content-type". - /// Note that this has no effect on HTTP/2 connections. - /// - /// Default is false. - pub fn http1_title_case_headers(&mut self, enabled: bool) -> &mut Self { - self.inner.http1_title_case_headers(enabled); - self - } - - /// Determines if HTTP/1 connections should preserve the original case of headers. - /// - /// By default, headers might be normalized. Enabling this will ensure headers retain their original casing. - /// This setting doesn't influence HTTP/2. - /// - /// Default is false. - pub fn http1_preserve_header_case(&mut self, enabled: bool) -> &mut Self { - self.inner.http1_preserve_header_case(enabled); - self - } - - /// Configures a timeout for how long the server will wait for client headers. - /// - /// If the client doesn't send all headers within this duration, the connection is terminated. - /// - /// Default is None, meaning no timeout. - pub fn http1_header_read_timeout(&mut self, val: Duration) -> &mut Self { - self.inner.http1_header_read_timeout(val); - self - } - - /// Specifies whether to use vectored writes for HTTP/1 connections. - /// - /// Vectored writes can be efficient for multiple non-contiguous data segments. - /// However, certain transports (like many TLS implementations) may not handle vectored writes well. - /// When disabled, data is flattened into a single buffer before writing. - /// - /// Default is `auto`, where the best method is determined dynamically. - pub fn http1_writev(&mut self, val: bool) -> &mut Self { - self.inner.http1_writev(val); - self - } - - /// Configures the server to exclusively support HTTP/2. - /// - /// When enabled, only HTTP/2 requests are processed, and HTTP/1 requests are rejected. - /// - /// Default is false. - pub fn http2_only(&mut self, val: bool) -> &mut Self { - self.inner.http2_only(val); - self - } - - /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2 - /// stream-level flow control. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE - pub fn http2_initial_stream_window_size(&mut self, sz: impl Into>) -> &mut Self { - self.inner.http2_initial_stream_window_size(sz); - self - } - - /// Sets the max connection-level flow control for HTTP2. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - pub fn http2_initial_connection_window_size( - &mut self, - sz: impl Into>, - ) -> &mut Self { - self.inner.http2_initial_connection_window_size(sz); - self - } - - /// Sets whether to use an adaptive flow control. - /// - /// Enabling this will override the limits set in - /// `http2_initial_stream_window_size` and - /// `http2_initial_connection_window_size`. - pub fn http2_adaptive_window(&mut self, enabled: bool) -> &mut Self { - self.inner.http2_adaptive_window(enabled); - self - } - - /// Enables the [extended CONNECT protocol]. - /// - /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 - pub fn http2_enable_connect_protocol(&mut self) -> &mut Self { - self.inner.http2_enable_connect_protocol(); - self - } - - /// Sets the maximum frame size to use for HTTP2. - /// - /// Passing `None` will do nothing. - /// - /// If not set, hyper will use a default. - pub fn http2_max_frame_size(&mut self, sz: impl Into>) -> &mut Self { - self.inner.http2_max_frame_size(sz); - self - } - - /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2 - /// connections. - /// - /// Default is no limit (`std::u32::MAX`). Passing `None` will do nothing. - /// - /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS - pub fn http2_max_concurrent_streams(&mut self, max: impl Into>) -> &mut Self { - self.inner.http2_max_concurrent_streams(max); - self - } - - /// Sets the max size of received header frames. - /// - /// Default is currently ~16MB, but may change. - pub fn http2_max_header_list_size(&mut self, max: u32) -> &mut Self { - self.inner.http2_max_header_list_size(max); - self - } - - /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent. - /// - /// This will default to the default value set by the [`h2` crate](https://crates.io/crates/h2). - /// As of v0.3.17, it is 20. - /// - /// See for more information. - pub fn http2_max_pending_accept_reset_streams( - &mut self, - max: impl Into>, - ) -> &mut Self { - self.inner.http2_max_pending_accept_reset_streams(max); - self - } - - /// Set the maximum write buffer size for each HTTP/2 stream. - /// - /// Default is currently ~400KB, but may change. - /// - /// # Panics - /// - /// The value must be no larger than `u32::MAX`. - pub fn http2_max_send_buf_size(&mut self, max: usize) -> &mut Self { - self.inner.http2_max_send_buf_size(max); - self - } - - /// Sets an interval for HTTP2 Ping frames should be sent to keep a - /// connection alive. - /// - /// Pass `None` to disable HTTP2 keep-alive. - /// - /// Default is currently disabled. - pub fn http2_keep_alive_interval( - &mut self, - interval: impl Into>, - ) -> &mut Self { - self.inner.http2_keep_alive_interval(interval); - self - } - - /// Sets a timeout for receiving an acknowledgement of the keep-alive ping. - /// - /// If the ping is not acknowledged within the timeout, the connection will - /// be closed. Does nothing if `http2_keep_alive_interval` is disabled. - /// - /// Default is 20 seconds. - pub fn http2_keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self { - self.inner.http2_keep_alive_timeout(timeout); - self - } - - /// Set the maximum buffer size for the HTTP/1 connection. - /// - /// Default is ~400kb. - /// - /// # Panics - /// - /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum. - pub fn max_buf_size(&mut self, max: usize) -> &mut Self { - self.inner.max_buf_size(max); - self - } - - /// Determines if multiple responses should be buffered and sent together to support pipelined responses. - /// - /// This can improve throughput in certain situations, but is experimental and might contain issues. - /// - /// Default is false. - pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self { - self.inner.pipeline_flush(enabled); - self - } -} diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..feacfa5 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,90 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::server::TlsStream; + +pub(crate) enum Transport { + Plain(IO), + Tls(Box>), +} + +impl Transport { + pub(crate) fn new_plain(io: IO) -> Self { + Self::Plain(io) + } + + pub(crate) fn new_tls(io: TlsStream) -> Self { + Self::Tls(Box::from(io)) + } +} + +impl AsyncRead for Transport +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_read(cx, buf), + Transport::Tls(io) => Pin::new(io).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Transport +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_write(cx, buf), + Transport::Tls(io) => Pin::new(io).poll_write(cx, buf), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_flush(cx), + Transport::Tls(io) => Pin::new(io).poll_flush(cx), + } + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_shutdown(cx), + Transport::Tls(io) => Pin::new(io).poll_shutdown(cx), + } + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Transport::Plain(io) => Pin::new(io).poll_write_vectored(cx, bufs), + Transport::Tls(io) => Pin::new(io).poll_write_vectored(cx, bufs), + } + } + + #[inline] + fn is_write_vectored(&self) -> bool { + match self { + Transport::Plain(io) => io.is_write_vectored(), + Transport::Tls(io) => io.is_write_vectored(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 932c092..569a2ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,130 +1,23 @@ -//! hyper-server is a [hyper] server implementation designed to be used with [axum] framework. -//! -//! # Features -//! -//! - HTTP/1 and HTTP/2 -//! - HTTPS through [rustls] or [openssl]. -//! - High performance through [hyper]. -//! - Using [tower] make service API. -//! - Very good [axum] compatibility. Likely to work with future [axum] releases. -//! -//! # Guide -//! -//! hyper-server can [`serve`] items that implement [`MakeService`] with some additional [trait -//! bounds](crate::service::MakeServiceRef). Make services that are [created] using [`axum`] -//! complies with those trait bounds out of the box. Therefore it is more convenient to use this -//! crate with [`axum`]. -//! -//! All examples in this crate uses [`axum`]. If you want to use this crate without [`axum`] it is -//! highly recommended to learn how [tower] works. -//! -//! [`Server::bind`] or [`bind`] function can be called to create a server that will bind to -//! provided [`SocketAddr`] when [`serve`] is called. -//! -//! A [`Handle`] can be passed to [`Server`](Server::handle) for additional utilities like shutdown -//! and graceful shutdown. -//! -//! [`bind_rustls`] can be called by providing [`RustlsConfig`] to create a HTTPS [`Server`] that -//! will bind on provided [`SocketAddr`]. [`RustlsConfig`] can be cloned, reload methods can be -//! used on clone to reload tls configuration. -//! -//! # Features -//! -//! * `tls-rustls` - activate [rustls] support. -//! * `tls-openssl` - activate [openssl] support. -//! -//! # Example -//! -//! A simple hello world application can be served like: -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind(addr) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` -//! -//! You can find more examples in [repository]. -//! -//! [axum]: https://crates.io/crates/axum -//! [bind]: crate::bind -//! [bind_rustls]: crate::bind_rustls -//! [created]: https://docs.rs/axum/0.3/axum/struct.Router.html#method.into_make_service -//! [hyper]: https://crates.io/crates/hyper -//! [openssl]: https://crates.io/crates/openssl -//! [repository]: https://github.com/valorem-labs-inc/hyper-server/examples -//! [rustls]: https://crates.io/crates/rustls -//! [tower]: https://crates.io/crates/tower -//! [`axum`]: https://docs.rs/axum/0.3 -//! [`serve`]: crate::server::Server::serve -//! [`MakeService`]: https://docs.rs/tower/0.4/tower/make/trait.MakeService.html -//! [`RustlsConfig`]: crate::tls_rustls::RustlsConfig -//! [`SocketAddr`]: std::net::SocketAddr - -#![forbid(unsafe_code)] -#![warn( - clippy::await_holding_lock, - clippy::cargo_common_metadata, - clippy::dbg_macro, - clippy::doc_markdown, - clippy::empty_enum, - clippy::enum_glob_use, - clippy::inefficient_to_string, - clippy::mem_forget, - clippy::mutex_integer, - clippy::needless_continue, - clippy::todo, - clippy::unimplemented, - clippy::wildcard_imports, - future_incompatible, - missing_docs, - missing_debug_implementations, - unreachable_pub -)] -#![cfg_attr(docsrs, feature(doc_cfg))] - -mod addr_incoming_config; -mod handle; -mod http_config; -mod notify_once; -mod server; - -pub mod accept; -pub mod service; - -pub use self::{ - addr_incoming_config::AddrIncomingConfig, - handle::Handle, - http_config::HttpConfig, - server::{bind, from_tcp, Server}, -}; - -#[cfg(feature = "tls-rustls")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] -pub mod tls_rustls; - -#[doc(inline)] -#[cfg(feature = "tls-rustls")] -pub use self::tls_rustls::export::{bind_rustls, from_tcp_rustls}; - -#[cfg(feature = "tls-openssl")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls-openssl")))] -pub mod tls_openssl; - -#[doc(inline)] -#[cfg(feature = "tls-openssl")] -pub use self::tls_openssl::bind_openssl; - -#[cfg(feature = "proxy-protocol")] -#[cfg_attr(docsrs, doc(cfg(feature = "proxy_protocol")))] -pub mod proxy_protocol; +#[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] +use tikv_jemallocator::Jemalloc; + +#[cfg(all(feature = "jemalloc", not(target_env = "msvc")))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; + +pub use error::{Error as TransportError, Kind as TransportErrorKind}; +pub use http::serve_http_connection; +pub use http::serve_http_with_shutdown; +pub use tcp::serve_tcp_incoming; +pub use tls::load_certs; +pub use tls::load_private_key; +pub use tls::serve_tls_incoming; + +mod error; +mod fuse; +mod http; +mod io; +mod tcp; +mod tls; + +pub(crate) type Error = Box; diff --git a/src/notify_once.rs b/src/notify_once.rs deleted file mode 100644 index e4564cd..0000000 --- a/src/notify_once.rs +++ /dev/null @@ -1,45 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::sync::Notify; - -/// A thread-safe utility that provides a one-time notification to waiters. -/// It utilizes an atomic boolean to ensure that the notification is sent only once. -#[derive(Debug, Default)] -pub(crate) struct NotifyOnce { - /// A flag indicating whether a notification has been sent. - notified: AtomicBool, - /// An asynchronous primitive from the Tokio library used for notifying tasks. - notify: Notify, -} - -impl NotifyOnce { - /// Notifies all waiting tasks, ensuring that the notification happens only once. - pub(crate) fn notify_waiters(&self) { - // Atomically set the `notified` flag to true. - self.notified.store(true, Ordering::SeqCst); - - // Notify all waiting tasks. - self.notify.notify_waiters(); - } - - /// Checks whether a notification has been sent. - /// - /// Returns: - /// - `true` if the notification has already been sent. - /// - `false` otherwise. - pub(crate) fn is_notified(&self) -> bool { - self.notified.load(Ordering::SeqCst) - } - - /// Awaits until a notification has been sent. - /// - /// This asynchronous function will immediately complete if a notification - /// has already been sent, otherwise it will await until it's notified. - pub(crate) async fn notified(&self) { - let future = self.notify.notified(); - - // If not notified, await on the future. - if !self.notified.load(Ordering::SeqCst) { - future.await; - } - } -} diff --git a/src/proxy_protocol/future.rs b/src/proxy_protocol/future.rs deleted file mode 100644 index bd9ab26..0000000 --- a/src/proxy_protocol/future.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Future types for PROXY protocol support. -use crate::accept::Accept; -use crate::proxy_protocol::ForwardClientIp; -use pin_project_lite::pin_project; -use std::{ - fmt, - future::Future, - io, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::Timeout; - -// A `pin_project` is a procedural macro used for safe field projection in conjunction -// with the Rust Pin API, which guarantees that certain types will not move in memory. -pin_project! { - /// This struct represents the future for the ProxyProtocolAcceptor. - /// The generic types are: - /// F: The future type. - /// A: The type that implements the Accept trait. - /// I: The IO type that supports both AsyncRead and AsyncWrite. - /// S: The service type. - pub struct ProxyProtocolAcceptorFuture - where - A: Accept, - { - #[pin] - inner: AcceptFuture, - } -} - -impl ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, -{ - // Constructor for creating a new ProxyProtocolAcceptorFuture. - pub(crate) fn new(future: Timeout, acceptor: A, service: S) -> Self { - let inner = AcceptFuture::ReadHeader { - future, - acceptor, - service: Some(service), - }; - Self { inner } - } -} - -// Implement Debug trait for ProxyProtocolAcceptorFuture to allow -// debugging and logging. -impl fmt::Debug for ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProxyProtocolAcceptorFuture").finish() - } -} - -pin_project! { - // AcceptFuture represents the internal states of ProxyProtocolAcceptorFuture. - // It can either be waiting to read the header or forward the client IP. - #[project = AcceptFutureProj] - enum AcceptFuture - where - A: Accept, - { - ReadHeader { - #[pin] - future: Timeout, - acceptor: A, - service: Option, - }, - ForwardIp { - #[pin] - future: A::Future, - client_address: Option, - }, - } -} - -impl Future for ProxyProtocolAcceptorFuture -where - A: Accept, - I: AsyncRead + AsyncWrite + Unpin, - // Future whose output is a result with either a tuple of stream and optional address, - // or an io::Error. - F: Future), io::Error>>, -{ - type Output = io::Result<(A::Stream, ForwardClientIp)>; - - // The main poll function that drives the future towards completion. - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - loop { - // Check the current state of the inner future. - match this.inner.as_mut().project() { - AcceptFutureProj::ReadHeader { - future, - acceptor, - service, - } => match future.poll(cx) { - Poll::Ready(Ok(Ok((stream, client_address)))) => { - let service = service.take().expect("future polled after ready"); - let future = acceptor.accept(stream, service); - - // Transition to the ForwardIp state after successfully reading the header. - this.inner.set(AcceptFuture::ForwardIp { - future, - client_address, - }); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - AcceptFutureProj::ForwardIp { - future, - client_address, - } => { - return match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let service = ForwardClientIp { - inner: service, - client_address: *client_address, - }; - - // Return the successfully processed stream and service. - Poll::Ready(Ok((stream, service))) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - }; - } - } - } - } -} diff --git a/src/proxy_protocol/mod.rs b/src/proxy_protocol/mod.rs deleted file mode 100644 index 0906eed..0000000 --- a/src/proxy_protocol/mod.rs +++ /dev/null @@ -1,841 +0,0 @@ -//! This feature allows the `hyper_server` to be used behind a layer 4 load balancer whilst the proxy -//! protocol is enabled to preserve the client IP address and port. -//! See The PROXY protocol spec for more details: . -//! -//! Any client address found in the proxy protocol header is forwarded on in the HTTP `forwarded` -//! header to be accessible by the rest server. -//! -//! Note: if you are setting a custom acceptor, `enable_proxy_protocol` must be called after this is set. -//! It is best to use directly before calling `serve` when the inner acceptor is already configured. -//! `ProxyProtocolAcceptor` wraps the initial acceptor, so the proxy header is removed from the -//! beginning of the stream before the messages are forwarded on. -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use std::net::SocketAddr; -//! use std::time::Duration; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! -//! // Can configure if you want different from the default of 5 seconds, -//! // otherwise passing `None` will use the default. -//! let proxy_header_timeout = Some(Duration::from_secs(2)); -//! -//! hyper_server::bind(addr) -//! .enable_proxy_protocol(proxy_header_timeout) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` -use crate::accept::Accept; -use std::{ - fmt, - future::Future, - io, - net::{IpAddr, SocketAddr}, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; - -use http::HeaderValue; -use http::Request; -use ppp::{v1, v2, HeaderResult}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite}, - time::timeout, -}; -use tower_service::Service; - -pub(crate) mod future; -use self::future::ProxyProtocolAcceptorFuture; - -/// The length of a v1 header in bytes. -const V1_PREFIX_LEN: usize = 5; -/// The maximum length of a v1 header in bytes. -const V1_MAX_LENGTH: usize = 107; -/// The terminator bytes of a v1 header. -const V1_TERMINATOR: &[u8] = b"\r\n"; -/// The prefix length of a v2 header in bytes. -const V2_PREFIX_LEN: usize = 12; -/// The minimum length of a v2 header in bytes. -const V2_MINIMUM_LEN: usize = 16; -/// The index of the start of the big-endian u16 length in the v2 header. -const V2_LENGTH_INDEX: usize = 14; -/// The length of the read buffer used to read the PROXY protocol header. -const READ_BUFFER_LEN: usize = 512; - -pub(crate) async fn read_proxy_header( - mut stream: I, -) -> Result<(I, Option), io::Error> -where - I: AsyncRead + Unpin, -{ - // Mutable buffer for storing stream data - let mut buffer = [0; READ_BUFFER_LEN]; - // Dynamic in case v2 header is too long - let mut dynamic_buffer = None; - - // Read prefix to check for v1, v2, or kill - stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; - - if &buffer[..V1_PREFIX_LEN] == v1::PROTOCOL_PREFIX.as_bytes() { - read_v1_header(&mut stream, &mut buffer).await?; - } else { - stream - .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) - .await?; - if &buffer[..V2_PREFIX_LEN] == v2::PROTOCOL_PREFIX { - dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid Proxy Protocol header detected", - )); - } - } - - // Choose which buffer to parse - let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); - - // Parse the header - let header = HeaderResult::parse(buffer); - match header { - HeaderResult::V1(Ok(header)) => { - let client_address = match header.addresses { - v1::Addresses::Tcp4(ip) => { - SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) - } - v1::Addresses::Tcp6(ip) => { - SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) - } - v1::Addresses::Unknown => { - // Return client address as `None` so that "unknown" is used in the http header - return Ok((stream, None)); - } - }; - - Ok((stream, Some(client_address))) - } - HeaderResult::V2(Ok(header)) => { - let client_address = match header.addresses { - v2::Addresses::IPv4(ip) => { - SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) - } - v2::Addresses::IPv6(ip) => { - SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) - } - v2::Addresses::Unix(unix) => { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "Unix socket addresses are not supported. Addresses: {:?}", - unix - ), - )); - } - v2::Addresses::Unspecified => { - // Return client address as `None` so that "unknown" is used in the http header - return Ok((stream, None)); - } - }; - - Ok((stream, Some(client_address))) - } - HeaderResult::V1(Err(_error)) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid V1 Proxy Protocol header received", - )), - HeaderResult::V2(Err(_error)) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid V2 Proxy Protocol header received", - )), - } -} - -async fn read_v2_header( - mut stream: I, - buffer: &mut [u8; READ_BUFFER_LEN], -) -> Result>, io::Error> -where - I: AsyncRead + Unpin, -{ - let length = - u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; - let full_length = V2_MINIMUM_LEN + length; - - // Switch to dynamic buffer if header is too long; v2 has no maximum length - if full_length > READ_BUFFER_LEN { - let mut dynamic_buffer = Vec::with_capacity(full_length); - dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); - - // Read the remaining header length - stream - .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) - .await?; - - Ok(Some(dynamic_buffer)) - } else { - // Read the remaining header length - stream - .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) - .await?; - - Ok(None) - } -} - -async fn read_v1_header( - mut stream: I, - buffer: &mut [u8; READ_BUFFER_LEN], -) -> Result<(), io::Error> -where - I: AsyncRead + Unpin, -{ - // read one byte at a time until terminator found - let mut end_found = false; - for i in V1_PREFIX_LEN..V1_MAX_LENGTH { - buffer[i] = stream.read_u8().await?; - - if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { - end_found = true; - break; - } - } - if !end_found { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "No valid Proxy Protocol header detected", - )); - } - - Ok(()) -} - -/// Middleware for adding client IP address to the request `forwarded` header. -/// see spec: -#[derive(Debug, Clone)] -pub struct ForwardClientIp { - inner: S, - client_address: Option, -} - -impl Service> for ForwardClientIp -where - S: Service>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - // The full socket address is available in the proxy header, hence why we include port - let mut forwarded_string = match self.client_address { - Some(socket_addr) => match socket_addr { - SocketAddr::V4(addr) => { - format!("for={}:{}", addr.ip(), addr.port()) - } - SocketAddr::V6(addr) => { - format!("for=\"[{}]:{}\"", addr.ip(), addr.port()) - } - }, - None => "for=unknown".to_string(), - }; - - if let Some(existing_value) = req.headers_mut().get("Forwarded") { - forwarded_string = format!( - "{}, {}", - existing_value.to_str().unwrap_or(""), - forwarded_string - ); - } - - if let Ok(header_value) = HeaderValue::from_str(&forwarded_string) { - req.headers_mut().insert("Forwarded", header_value); - } - - self.inner.call(req) - } -} - -/// Acceptor wrapper for receiving Proxy Protocol headers. -#[derive(Clone)] -pub struct ProxyProtocolAcceptor { - inner: A, - parsing_timeout: Duration, -} - -impl ProxyProtocolAcceptor { - /// Create a new proxy protocol acceptor from an initial acceptor. - /// This is compatible with tls acceptors. - pub fn new(inner: A) -> Self { - #[cfg(not(test))] - let parsing_timeout = Duration::from_secs(5); - - // Don't force tests to wait too long. - #[cfg(test)] - let parsing_timeout = Duration::from_secs(1); - - Self { - inner, - parsing_timeout, - } - } - - /// Override the default Proxy Header parsing timeout. - pub fn parsing_timeout(mut self, val: Duration) -> Self { - self.parsing_timeout = val; - self - } -} - -impl ProxyProtocolAcceptor { - /// Overwrite inner acceptor. - pub fn acceptor(self, acceptor: Acceptor) -> ProxyProtocolAcceptor { - ProxyProtocolAcceptor { - inner: acceptor, - parsing_timeout: self.parsing_timeout, - } - } -} - -impl Accept for ProxyProtocolAcceptor -where - A: Accept + Clone, - A::Stream: AsyncRead + AsyncWrite + Unpin, - I: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Stream = A::Stream; - type Service = ForwardClientIp; - type Future = ProxyProtocolAcceptorFuture< - Pin), io::Error>> + Send>>, - A, - I, - S, - >; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let future = Box::pin(read_proxy_header(stream)); - - ProxyProtocolAcceptorFuture::new( - timeout(self.parsing_timeout, future), - self.inner.clone(), - service, - ) - } -} - -impl fmt::Debug for ProxyProtocolAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProxyProtocolAcceptor").finish() - } -} - -#[cfg(test)] -mod tests { - #[cfg(feature = "tls-openssl")] - use crate::tls_openssl::{ - self, - tests::{dns_name as openssl_dns_name, tls_connector as openssl_connector}, - OpenSSLConfig, - }; - #[cfg(feature = "tls-rustls")] - use crate::tls_rustls::{ - self, - tests::{dns_name as rustls_dns_name, tls_connector as rustls_connector}, - RustlsConfig, - }; - use crate::{handle::Handle, server::Server}; - use axum::http::Response; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use ppp::v2::{Builder, Command, Protocol, Type, Version}; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use tokio::{ - net::{TcpListener, TcpStream}, - task::JoinHandle, - time::timeout, - }; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, _client_addr) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn server_receives_client_address_v1() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::V1) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[cfg(feature = "tls-rustls")] - #[tokio::test] - async fn rustls_server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_rustls_server().await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = rustls_connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[cfg(feature = "tls-openssl")] - #[tokio::test] - async fn openssl_server_receives_client_address() { - let (_handle, _server_task, server_addr) = start_openssl_server().await; - - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, client_addr) = openssl_connect(addr).await; - - let (parts, body) = send_empty_request(&mut client).await; - - // Check for the Forwarded header - let forwarded_header = parts - .headers - .get("Forwarded") - .expect("No Forwarded header present") - .to_str() - .expect("Failed to convert Forwarded header to str"); - - assert!(forwarded_header.contains(&format!("for={}", client_addr))); - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn not_parsing_when_header_present_fails() { - // Start the server with proxy protocol disabled - let (_handle, _server_task, server_addr) = start_server(false).await; - - // Start the proxy - let addr = start_proxy(server_addr, ProxyVersion::V2) - .await - .expect("Failed to start proxy"); - - // Connect to the proxy - let (mut client, _conn, _client_addr) = connect(addr).await; - - // Send a request to the proxy - match client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - { - // TODO(This should fail when there is no proxy protocol support, perhaps) - Ok(_o) => { - //dbg!(_o); - //() - } - Err(e) => { - if e.is_incomplete_message() { - } else { - panic!("Received unexpected error"); - } - } - } - } - - #[tokio::test] - async fn parsing_when_header_not_present_fails() { - let (_handle, _server_task, server_addr) = start_server(true).await; - - let addr = start_proxy(server_addr, ProxyVersion::None) - .await - .expect("Failed to start proxy"); - - let (mut client, _conn, _client_addr) = connect(addr).await; - - match client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - { - Ok(_) => panic!("Should have failed"), - Err(e) => { - if e.is_incomplete_message() { - } else { - panic!("Received unexpected error"); - } - } - } - } - - async fn forward_ip_handler(req: Request) -> Response { - let mut response = Response::new(Body::from("Hello, world!")); - - if let Some(header_value) = req.headers().get("Forwarded") { - response - .headers_mut() - .insert("Forwarded", header_value.clone()); - } - - response - } - - async fn start_server( - parse_proxy_header: bool, - ) -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - if parse_proxy_header { - Server::bind(addr) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - } else { - Server::bind(addr) - .handle(server_handle) - .serve(app.into_make_service()) - .await - } - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[cfg(feature = "tls-rustls")] - async fn start_rustls_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await?; - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, config) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[cfg(feature = "tls-openssl")] - async fn start_openssl_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(forward_ip_handler)); - - let config = OpenSSLConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_openssl::bind_openssl(addr, config) - .handle(server_handle) - .enable_proxy_protocol(None) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - #[derive(Debug, Clone, Copy)] - enum ProxyVersion { - V1, - V2, - None, - } - - async fn start_proxy( - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> Result> { - let proxy_address = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(proxy_address).await?; - let proxy_address = listener.local_addr()?; - - let _proxy_task = tokio::spawn(async move { - loop { - match listener.accept().await { - Ok((client_stream, _)) => { - tokio::spawn(async move { - if let Err(e) = - handle_conn(client_stream, server_address, proxy_version).await - { - println!("Error handling connection: {:?}", e); - } - }); - } - Err(e) => println!("Failed to accept a connection: {:?}", e), - } - } - }); - - Ok(proxy_address) - } - - async fn handle_conn( - mut client_stream: TcpStream, - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> io::Result<()> { - let client_address = client_stream.peer_addr()?; // Get the address before splitting - let mut server_stream = TcpStream::connect(server_address).await?; - let server_address = server_stream.peer_addr()?; // Get the address before splitting - - let (mut client_read, mut client_write) = client_stream.split(); - let (mut server_read, mut server_write) = server_stream.split(); - - send_proxy_header( - &mut server_write, - client_address, - server_address, - proxy_version, - ) - .await?; - - let duration = Duration::from_secs(1); - let client_to_server = async { - match timeout(duration, transfer(&mut client_read, &mut server_write)).await { - Ok(result) => result, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "Client to Server transfer timed out", - )), - } - }; - - let server_to_client = async { - match timeout(duration, transfer(&mut server_read, &mut client_write)).await { - Ok(result) => result, - Err(_) => Err(io::Error::new( - io::ErrorKind::TimedOut, - "Server to Client transfer timed out", - )), - } - }; - - let _ = tokio::try_join!(client_to_server, server_to_client); - - Ok(()) - } - - async fn transfer( - read_stream: &mut (impl AsyncReadExt + Unpin), - write_stream: &mut (impl AsyncWriteExt + Unpin), - ) -> io::Result<()> { - let mut buf = [0; 4096]; - loop { - let n = read_stream.read(&mut buf).await?; - if n == 0 { - break; // EOF - } - write_stream.write_all(&buf[..n]).await?; - } - Ok(()) - } - - async fn send_proxy_header( - write_stream: &mut (impl AsyncWriteExt + Unpin), - client_address: SocketAddr, - server_address: SocketAddr, - proxy_version: ProxyVersion, - ) -> io::Result<()> { - match proxy_version { - ProxyVersion::V1 => { - let header = ppp::v1::Addresses::from((client_address, server_address)).to_string(); - - for byte in header.as_bytes() { - write_stream.write_all(&[*byte]).await?; - } - } - ProxyVersion::V2 => { - let mut header = Builder::with_addresses( - // Declare header as mutable - Version::Two | Command::Proxy, - Protocol::Stream, - (client_address, server_address), - ) - .write_tlv(Type::NoOp, b"Hello, World!")? - .build()?; - - for byte in header.drain(..) { - write_stream.write_all(&[byte]).await?; - } - } - ProxyVersion::None => {} - } - - Ok(()) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - - let (send_request, connection) = handshake(stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - #[cfg(feature = "tls-rustls")] - async fn rustls_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - let tls_stream = rustls_connector() - .connect(rustls_dns_name(), stream) - .await - .unwrap(); - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - #[cfg(feature = "tls-openssl")] - async fn openssl_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { - let stream = TcpStream::connect(addr).await.unwrap(); - let client_addr = stream.local_addr().unwrap(); - let tls_stream = openssl_connector(openssl_dns_name(), stream).await; - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task, client_addr) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } -} diff --git a/src/server.rs b/src/server.rs deleted file mode 100644 index 6db138a..0000000 --- a/src/server.rs +++ /dev/null @@ -1,523 +0,0 @@ -#[cfg(feature = "proxy-protocol")] -use crate::proxy_protocol::ProxyProtocolAcceptor; -use crate::{ - accept::{Accept, DefaultAcceptor}, - addr_incoming_config::AddrIncomingConfig, - handle::Handle, - http_config::HttpConfig, - service::{MakeServiceRef, SendService}, -}; -use futures_util::future::poll_fn; -use http::Request; -use hyper::server::{ - accept::Accept as HyperAccept, - conn::{AddrIncoming, AddrStream}, -}; -#[cfg(feature = "proxy-protocol")] -use std::time::Duration; -use std::{ - io::{self, ErrorKind}, - net::SocketAddr, - pin::Pin, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::TcpListener, -}; - -/// Represents an HTTP server with customization capabilities for handling incoming requests. -#[derive(Debug)] -pub struct Server { - acceptor: A, - listener: Listener, - addr_incoming_conf: AddrIncomingConfig, - handle: Handle, - http_conf: HttpConfig, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: bool, -} - -/// Enum representing the ways the server can be initialized - either by binding to an address or from a standard TCP listener. -#[derive(Debug)] -enum Listener { - Bind(SocketAddr), - Std(std::net::TcpListener), -} - -/// Creates a new [`Server`] instance that binds to the provided address. -pub fn bind(addr: SocketAddr) -> Server { - Server::bind(addr) -} - -/// Creates a new [`Server`] instance using an existing `std::net::TcpListener`. -pub fn from_tcp(listener: std::net::TcpListener) -> Server { - Server::from_tcp(listener) -} - -impl Server { - /// Constructs a server bound to the provided address. - pub fn bind(addr: SocketAddr) -> Self { - let acceptor = DefaultAcceptor::new(); - let handle = Handle::new(); - - Self { - acceptor, - listener: Listener::Bind(addr), - addr_incoming_conf: AddrIncomingConfig::default(), - handle, - http_conf: HttpConfig::default(), - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: false, - } - } - - /// Constructs a server from an existing `std::net::TcpListener`. - pub fn from_tcp(listener: std::net::TcpListener) -> Self { - let acceptor = DefaultAcceptor::new(); - let handle = Handle::new(); - - Self { - acceptor, - listener: Listener::Std(listener), - addr_incoming_conf: AddrIncomingConfig::default(), - handle, - http_conf: HttpConfig::default(), - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: false, - } - } -} - -impl Server { - /// Replace the current acceptor with a new one. - pub fn acceptor(self, acceptor: Acceptor) -> Server { - #[cfg(feature = "proxy-protocol")] - if self.proxy_acceptor_set { - panic!("Overwriting the acceptor after proxy protocol is enabled is not supported. Configure the acceptor first in the builder, then enable proxy protocol."); - } - - Server { - acceptor, - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: self.proxy_acceptor_set, - } - } - - #[cfg(feature = "proxy-protocol")] - /// Enable proxy protocol header parsing. - /// Note has to be called after initial acceptor is set. - pub fn enable_proxy_protocol( - self, - parsing_timeout: Option, - ) -> Server> { - let initial_acceptor = self.acceptor; - let mut acceptor = ProxyProtocolAcceptor::new(initial_acceptor); - - if let Some(val) = parsing_timeout { - acceptor = acceptor.parsing_timeout(val); - } - - Server { - acceptor, - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - proxy_acceptor_set: true, - } - } - - /// Maps the current acceptor to a new type. - pub fn map(self, acceptor: F) -> Server - where - F: FnOnce(A) -> Acceptor, - { - Server { - acceptor: acceptor(self.acceptor), - listener: self.listener, - addr_incoming_conf: self.addr_incoming_conf, - handle: self.handle, - http_conf: self.http_conf, - #[cfg(feature = "proxy-protocol")] - proxy_acceptor_set: self.proxy_acceptor_set, - } - } - - /// Retrieves a reference to the server's acceptor. - pub fn get_ref(&self) -> &A { - &self.acceptor - } - - /// Retrieves a mutable reference to the server's acceptor. - pub fn get_mut(&mut self) -> &mut A { - &mut self.acceptor - } - - /// Provides the server with a handle for extra utilities. - pub fn handle(mut self, handle: Handle) -> Self { - self.handle = handle; - self - } - - /// Replaces the current HTTP configuration. - pub fn http_config(mut self, config: HttpConfig) -> Self { - self.http_conf = config; - self - } - - /// Replaces the current incoming address configuration. - pub fn addr_incoming_config(mut self, config: AddrIncomingConfig) -> Self { - self.addr_incoming_conf = config; - self - } - - /// Serves the provided `MakeService`. - /// - /// The `MakeService` is responsible for constructing services for each incoming connection. - /// Each service is then used to handle requests from that specific connection. - /// - /// # Arguments - /// - `make_service`: A mutable reference to a type implementing the `MakeServiceRef` trait. - /// This will be used to produce a service for each incoming connection. - /// - /// # Errors - /// - /// This method can return errors in the following scenarios: - /// - When binding to an address fails. - /// - If the `make_service` function encounters an error during its `poll_ready` call. - /// It's worth noting that this error scenario doesn't typically occur with `axum` make services. - /// - pub async fn serve(self, mut make_service: M) -> io::Result<()> - where - M: MakeServiceRef>, - A: Accept + Clone + Send + Sync + 'static, - A::Stream: AsyncRead + AsyncWrite + Unpin + Send, - A::Service: SendService> + Send, - A::Future: Send, - { - // Extract relevant fields from `self` for easier access. - let acceptor = self.acceptor; - let addr_incoming_conf = self.addr_incoming_conf; - let handle = self.handle; - let http_conf = self.http_conf; - - // Bind the incoming connections. Notify the handle if an error occurs during binding. - let mut incoming = match bind_incoming(self.listener, addr_incoming_conf).await { - Ok(v) => v, - Err(e) => { - handle.notify_listening(None); - return Err(e); - } - }; - - // Notify the handle about the server's listening state. - handle.notify_listening(Some(incoming.local_addr())); - - // This is the main loop that accepts incoming connections and spawns tasks to handle them. - let accept_loop_future = async { - loop { - // Wait for a new connection or for the server to be signaled to shut down. - let addr_stream = tokio::select! { - biased; - result = accept(&mut incoming) => result?, - _ = handle.wait_graceful_shutdown() => return Ok(()), - }; - - // Ensure the `make_service` is ready to produce another service. - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .map_err(io_other)?; - - // Create a service for this connection. - let service = match make_service.make_service(&addr_stream).await { - Ok(service) => service, - Err(_) => continue, // TODO: Consider logging or handling this error in a more detailed manner. - }; - - // Clone necessary objects for the spawned task. - let acceptor = acceptor.clone(); - let watcher = handle.watcher(); - let http_conf = http_conf.clone(); - - // Spawn a new task to handle the connection. - tokio::spawn(async move { - if let Ok((stream, send_service)) = acceptor.accept(addr_stream, service).await - { - let service = send_service.into_service(); - - let mut serve_future = http_conf - .inner - .serve_connection(stream, service) - .with_upgrades(); - - // Wait for either the server to be shut down or the connection to finish. - tokio::select! { - biased; - _ = watcher.wait_graceful_shutdown() => { - // Initiate a graceful shutdown. - Pin::new(&mut serve_future).graceful_shutdown(); - tokio::select! { - biased; - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } - // TODO: Consider logging or handling any errors that occur during acceptance. - }); - } - }; - - // Wait for either the server to be fully shut down or an error to occur. - let result = tokio::select! { - biased; - _ = handle.wait_shutdown() => return Ok(()), - result = accept_loop_future => result, - }; - - // Handle potential errors. - // TODO: Consider removing the Clippy annotation by restructuring this error handling. - #[allow(clippy::question_mark)] - if let Err(e) = result { - return Err(e); - } - - // Wait for all connections to end. - handle.wait_connections_end().await; - - Ok(()) - } -} - -/// Binds the listener based on the provided configuration and returns an [`AddrIncoming`] -/// which will produce [`AddrStream`]s for incoming connections. -/// -/// The function takes into account different ways the listener might be set up, -/// either by binding to a provided address or by using an existing standard listener. -/// -/// # Arguments -/// -/// - `listener`: The listener configuration. Can be either a direct bind address or an existing standard listener. -/// - `addr_incoming_conf`: Configuration for the incoming connections, such as TCP keepalive settings. -/// -/// # Errors -/// -/// Returns an `io::Error` if: -/// - Binding the listener fails. -/// - Setting the listener to non-blocking mode fails. -/// - The listener cannot be converted to a [`TcpListener`]. -/// - An error occurs when creating the [`AddrIncoming`]. -/// -async fn bind_incoming( - listener: Listener, - addr_incoming_conf: AddrIncomingConfig, -) -> io::Result { - let listener = match listener { - Listener::Bind(addr) => TcpListener::bind(addr).await?, - Listener::Std(std_listener) => { - std_listener.set_nonblocking(true)?; - TcpListener::from_std(std_listener)? - } - }; - let mut incoming = AddrIncoming::from_listener(listener).map_err(io_other)?; - - // Apply configuration settings to the incoming connection handler. - incoming.set_sleep_on_errors(addr_incoming_conf.tcp_sleep_on_accept_errors); - incoming.set_keepalive(addr_incoming_conf.tcp_keepalive); - incoming.set_keepalive_interval(addr_incoming_conf.tcp_keepalive_interval); - incoming.set_keepalive_retries(addr_incoming_conf.tcp_keepalive_retries); - incoming.set_nodelay(addr_incoming_conf.tcp_nodelay); - - Ok(incoming) -} - -/// Awaits and accepts a new incoming connection. -/// -/// This function will poll the given `incoming` object until a new connection is ready to be accepted. -/// -/// # Arguments -/// -/// - `incoming`: The incoming connection handler from which new connections will be accepted. -/// -/// # Returns -/// -/// Returns the accepted [`AddrStream`] which represents a specific incoming connection. -/// -/// # Panics -/// -/// This function will panic if the `poll_accept` method returns `None`, which should never happen as per the Hyper documentation. -/// -pub(crate) async fn accept(incoming: &mut AddrIncoming) -> io::Result { - let mut incoming = Pin::new(incoming); - - // Always [`Option::Some`]. - // According to: https://docs.rs/hyper/0.14.14/src/hyper/server/tcp.rs.html#165 - poll_fn(|cx| incoming.as_mut().poll_accept(cx)) - .await - .unwrap() -} - -/// Type definition for a boxed error which can be sent between threads and is Sync. -type BoxError = Box; - -/// Converts any error into an `io::Error` of kind `Other`. -/// -/// This function can be used to create a uniform `io::Error` response for various error types. -/// -/// # Arguments -/// -/// - `error`: The error to be converted. -/// -/// # Returns -/// -/// Returns an `io::Error` with the kind set to `Other` and the provided error as its cause. -/// -pub(crate) fn io_other>(error: E) -> io::Error { - io::Error::new(ErrorKind::Other, error) -} - -#[cfg(test)] -mod tests { - use crate::{handle::Handle, server::Server}; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // TODO(This does not shut down gracefully) - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - Server::bind(addr) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - - let (send_request, connection) = handshake(stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } -} diff --git a/src/service.rs b/src/service.rs deleted file mode 100644 index ee8848d..0000000 --- a/src/service.rs +++ /dev/null @@ -1,164 +0,0 @@ -//! Module containing service traits. -//! These traits are vital for handling requests and creating services within the server. - -use http::Response; -use http_body::Body; -use std::{ - future::Future, - task::{Context, Poll}, -}; -use tower_service::Service; - -// TODO(Document the types here to disable the clippy annotation) - -/// An alias trait for the [`Service`] trait, specialized with required bounds for the server's service function. -/// This trait has been sealed, ensuring it cannot be implemented by types outside of this crate. -/// -/// It provides constraints for the body data, errors, and asynchronous behavior that fits the server's needs. -#[allow(missing_docs)] -pub trait SendService: send_service::Sealed { - type Service: Service< - Request, - Response = Response, - Error = Self::Error, - Future = Self::Future, - > + Send - + 'static; - - type Body: Body + Send + 'static; - type BodyData: Send + 'static; - type BodyError: Into>; - - type Error: Into>; - type Future: Future, Self::Error>> + Send + 'static; - - /// Convert this type into a service. - fn into_service(self) -> Self::Service; -} - -impl send_service::Sealed for T -where - T: Service>, - T::Error: Into>, - T::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, -{ -} - -impl SendService for T -where - T: Service> + Send + 'static, - T::Error: Into>, - T::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, -{ - type Service = T; - - type Body = B; - type BodyData = B::Data; - type BodyError = B::Error; - - type Error = T::Error; - - type Future = T::Future; - - fn into_service(self) -> Self::Service { - self - } -} - -/// A variant of the [`MakeService`] trait that accepts a `&Target` reference. -/// This trait has been sealed, ensuring it cannot be implemented by types outside of this crate. -/// It is specifically designed for the server's `serve` function. -/// -/// This trait provides a mechanism to create services upon request, with the required trait bounds. -/// -/// [`MakeService`]: https://docs.rs/tower/0.4/tower/make/trait.MakeService.html -#[allow(missing_docs)] -pub trait MakeServiceRef: make_service_ref::Sealed<(Target, Request)> { - type Service: Service< - Request, - Response = Response, - Error = Self::Error, - Future = Self::Future, - > + Send - + 'static; - - type Body: Body + Send + 'static; - type BodyData: Send + 'static; - type BodyError: Into>; - - type Error: Into>; - type Future: Future, Self::Error>> + Send + 'static; - - type MakeError: Into>; - type MakeFuture: Future>; - - /// Polls to check if the service factory is ready to create a service. - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; - - /// Creates and returns a service for the provided target. - fn make_service(&mut self, target: &Target) -> Self::MakeFuture; -} - -impl make_service_ref::Sealed<(Target, Request)> for T -where - T: for<'a> Service<&'a Target, Response = S, Error = E, Future = F>, - S: Service> + Send + 'static, - S::Error: Into>, - S::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, - E: Into>, - F: Future>, -{ -} - -impl MakeServiceRef for T -where - T: for<'a> Service<&'a Target, Response = S, Error = E, Future = F>, - S: Service> + Send + 'static, - S::Error: Into>, - S::Future: Send + 'static, - B: Body + Send + 'static, - B::Data: Send + 'static, - B::Error: Into>, - E: Into>, - F: Future>, -{ - type Service = S; - - type Body = B; - type BodyData = B::Data; - type BodyError = B::Error; - - type Error = S::Error; - - type Future = S::Future; - - type MakeError = E; - type MakeFuture = F; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_ready(cx) - } - - fn make_service(&mut self, target: &Target) -> Self::MakeFuture { - self.call(target) - } -} - -// Sealed traits prevent external implementations of our core traits. -// This provides future compatibility guarantees. -mod send_service { - pub trait Sealed {} -} - -mod make_service_ref { - pub trait Sealed {} -} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..98347d8 --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,217 @@ +use crate::Error; +use std::{io, ops::ControlFlow}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::{Stream, StreamExt}; +use tracing::debug; + +/// Handles errors that occur during TCP connection acceptance. +/// +/// This function determines whether an error should be treated as fatal (breaking the accept loop) +/// or non-fatal (allowing the loop to continue). +/// +/// # Arguments +/// +/// * `e` - The error to handle, which can be converted into the crate's `Error` type. +/// +/// # Returns +/// +/// * `ControlFlow::Continue(())` if the error is non-fatal and the accept loop should continue. +/// * `ControlFlow::Break(Error)` if the error is fatal and the accept loop should terminate. +/// +/// # Error Handling +/// +/// The function categorizes errors as follows: +/// - Non-fatal errors: ConnectionAborted, Interrupted, InvalidData, WouldBlock +/// - Fatal errors: All other error types +fn handle_accept_error(e: impl Into) -> ControlFlow { + let e = e.into(); + debug!(error = %e, "TCP accept loop error"); + + // Check if the error is an I/O error + if let Some(e) = e.downcast_ref::() { + // Determine if the error is non-fatal + if matches!( + e.kind(), + io::ErrorKind::ConnectionAborted + | io::ErrorKind::Interrupted + | io::ErrorKind::InvalidData + | io::ErrorKind::WouldBlock + ) { + return ControlFlow::Continue(()); + } + } + + // If not a non-fatal I/O error, treat as fatal + ControlFlow::Break(e) +} + +/// Creates a stream that yields a TCP stream for each incoming connection. +/// +/// This function takes a stream of incoming connections and handles errors that may occur +/// during the acceptance process. It will continue to yield connections even if non-fatal +/// errors occur, but will terminate if a fatal error is encountered. +/// +/// # Type Parameters +/// +/// * `IO`: The type of the I/O object yielded by the incoming stream. +/// * `IE`: The type of the error that can be produced by the incoming stream. +/// +/// # Arguments +/// +/// * `incoming`: A stream that yields results of incoming connection attempts. +/// +/// # Returns +/// +/// A stream that yields `Result` for each incoming connection. +/// +/// # Error Handling +/// +/// This function uses `handle_accept_error` to determine whether to continue accepting +/// connections after an error occurs. Non-fatal errors are logged and skipped, while +/// fatal errors cause the stream to yield an error and terminate. +#[inline] +pub fn serve_tcp_incoming( + incoming: impl Stream> + Send + 'static, +) -> impl Stream> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, +{ + async_stream::stream! { + let mut incoming = Box::pin(incoming); + + while let Some(item) = incoming.next().await { + match item { + Ok(io) => yield Ok(io), + Err(e) => match handle_accept_error(e.into()) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(e) => yield Err(e), + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::SocketAddr; + use std::pin::Pin; + use tokio::net::{TcpListener, TcpStream}; + use tokio_stream::wrappers::TcpListenerStream; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn test_handle_accept_error() { + // Test non-fatal errors + let non_fatal_errors = vec![ + io::ErrorKind::ConnectionAborted, + io::ErrorKind::Interrupted, + io::ErrorKind::InvalidData, + io::ErrorKind::WouldBlock, + ]; + + for kind in non_fatal_errors { + let error = io::Error::new(kind, "Test error"); + assert!(matches!( + handle_accept_error(error), + ControlFlow::Continue(()) + )); + } + + // Test fatal error + let fatal_error = io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied"); + assert!(matches!( + handle_accept_error(fatal_error), + ControlFlow::Break(_) + )); + } + + #[tokio::test] + async fn test_serve_tcp_incoming_success() -> Result<(), Box> { + // Set up a TCP listener + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let bound_addr = listener.local_addr()?; + let stream = TcpListenerStream::new(listener); + let mut incoming = Box::pin(serve_tcp_incoming(stream)); + + // Spawn a task to accept one connection + let accept_task = tokio::spawn(async move { incoming.next().await }); + + // Connect to the server + let _client = TcpStream::connect(bound_addr).await?; + + // Check that the connection was accepted + let result = accept_task.await?.unwrap(); + assert!(result.is_ok()); + + Ok(()) + } + + #[tokio::test] + async fn test_serve_tcp_incoming_with_errors() { + // Create a mock stream that yields both successful connections and errors + let mock_stream = tokio_stream::iter(vec![ + Ok(MockIO), + Err(io::Error::new(io::ErrorKind::ConnectionAborted, "Aborted")), + Ok(MockIO), + Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Permission denied", + )), + ]); + + let mut incoming = Box::pin(serve_tcp_incoming(mock_stream)); + + // First connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Second connection (aborted) should be skipped + // Third connection should be successful + assert!(incoming.next().await.unwrap().is_ok()); + + // Fourth connection (permission denied) should break the stream + assert!(incoming.next().await.unwrap().is_err()); + + // Stream should be exhausted + assert!(incoming.next().await.is_none()); + } + + // Mock IO type for testing + struct MockIO; + + impl AsyncRead for MockIO { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for MockIO { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &[u8], + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(0)) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..6f55978 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,382 @@ +use crate::Error; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use std::{fs, io}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::TlsAcceptor; +use tokio_stream::{Stream, StreamExt}; + +/// Creates a stream of TLS-encrypted connections from a stream of TCP connections. +/// +/// This function takes a stream of TCP connections and a TLS acceptor, and produces +/// a new stream that yields TLS-encrypted connections. It handles both the successful +/// case of establishing a TLS connection and the error cases. +/// +/// # Type Parameters +/// +/// * `IO`: The I/O type representing the underlying TCP connection. It must implement +/// `AsyncRead`, `AsyncWrite`, `Unpin`, `Send`, and have a static lifetime. +/// +/// # Arguments +/// +/// * `tcp_stream`: A stream that yields `Result` items, representing incoming +/// TCP connections or errors. +/// * `tls`: A `TlsAcceptor` used to perform the TLS handshake on each TCP connection. +/// +/// # Returns +/// +/// A new `Stream` that yields `Result, Error>` items. +/// Each item is either a successfully established TLS connection or an error. +/// +/// # Error Handling +/// +/// - If the input `tcp_stream` yields an error, that error is propagated. +/// - If the TLS handshake fails, the error is wrapped in the crate's `Error` type. +#[inline] +pub fn serve_tls_incoming( + tcp_stream: impl Stream>, + tls: TlsAcceptor, +) -> impl Stream, Error>> +where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + // Transform each item in the TCP stream into a TLS stream + tcp_stream.then(move |result| { + // Clone the TLS acceptor for each connection + // This is necessary because the acceptor is moved into the async block + let tls = tls.clone(); + + async move { + match result { + // If the TCP connection was successfully established + Ok(io) => { + // Attempt to perform the TLS handshake + // If successful, return the TLS stream; otherwise, wrap the error + tls.accept(io).await.map_err(Error::from) + } + // If there was an error establishing the TCP connection, propagate it + Err(e) => Err(e), + } + } + }) +} + +/// Load the public certificate from a file. +/// +/// This function reads a PEM-encoded certificate file and returns a vector of +/// parsed certificates. +/// +/// # Arguments +/// +/// * `filename`: The path to the certificate file. +/// +/// # Returns +/// +/// A `Result` containing a vector of `CertificateDer` on success, or an `io::Error` on failure. +#[inline] +pub fn load_certs(filename: &str) -> io::Result>> { + // Open certificate file + let certfile = fs::File::open(filename)?; + let mut reader = io::BufReader::new(certfile); + + // Load and return certificates + // The `collect()` method is used to gather all certificates into a vector + rustls_pemfile::certs(&mut reader).collect() +} + +/// Load the private key from a file. +/// +/// This function reads a PEM-encoded private key file and returns the parsed private key. +/// +/// # Arguments +/// +/// * `filename`: The path to the private key file. +/// +/// # Returns +/// +/// A `Result` containing a `PrivateKeyDer` on success, or an `io::Error` on failure. +#[inline] +pub fn load_private_key(filename: &str) -> io::Result> { + // Open keyfile + let keyfile = fs::File::open(filename)?; + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key + // The `?` operator is used for error propagation + rustls_pemfile::private_key(&mut reader)? + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found in file")) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tcp::serve_tcp_incoming; + use futures::StreamExt; + use rustls::pki_types::{CertificateDer, ServerName}; + use rustls::{ClientConfig, ServerConfig}; + use std::net::SocketAddr; + use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::{TcpListener, TcpStream}; + use tokio_rustls::TlsAcceptor; + use tokio_stream::wrappers::TcpListenerStream; + use tracing::{debug, error, info, warn}; + + // Helper function to create a TLS acceptor for testing + async fn create_test_tls_acceptor() -> io::Result { + debug!("Creating test TLS acceptor"); + let certs = load_certs("examples/sample.pem")?; + let key = load_private_key("examples/sample.rsa")?; + + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| { + error!("Failed to create ServerConfig: {}", e); + io::Error::new(io::ErrorKind::Other, e) + })?; + + Ok(TlsAcceptor::from(Arc::new(config))) + } + + #[tokio::test] + async fn test_tls_incoming_success() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_success"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + let tls_acceptor = create_test_tls_acceptor().await?; + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = tls_stream.next().await; + debug!("Server received connection: {:?}", result.is_some()); + result + }); + + // Connect to the server with a TLS client + let mut root_store = rustls::RootCertStore::empty(); + let certs = load_certs("examples/sample.pem")?; + root_store.add_parsable_certificates(certs); + + let client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config)); + + debug!("Client connecting to {}", server_addr); + let tcp_stream = TcpStream::connect(server_addr).await?; + let domain = ServerName::try_from("localhost")?; + let _client_stream = connector.connect(domain, tcp_stream).await?; + debug!("Client connected successfully"); + + // Wait for the server to accept the connection + let result = server_task + .await? + .ok_or("Server task completed without result")?; + match result { + Ok(_) => info!("TLS connection established successfully"), + Err(ref e) => error!("TLS connection failed: {}", e), + } + assert!(result.is_ok()); + + Ok(()) + } + + #[tokio::test] + async fn test_tls_incoming_invalid_cert() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_invalid_cert"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + // Create a TLS acceptor with an invalid certificate + let invalid_cert = vec![CertificateDer::from(vec![0; 32])]; // Invalid certificate + let key = load_private_key("examples/sample.rsa")?; + + // Expect this to fail and log the error + let config_result = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(invalid_cert, key); + + match config_result { + Ok(_) => warn!("ServerConfig creation unexpectedly succeeded with invalid cert"), + Err(e) => info!("ServerConfig creation failed as expected: {}", e), + } + + // Use a valid certificate for the server to allow the test to continue + let valid_certs = load_certs("examples/sample.pem")?; + let valid_key = load_private_key("examples/sample.rsa")?; + let config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(valid_certs, valid_key) + .expect("ServerConfig creation should succeed with valid cert"); + + let tls_acceptor = TlsAcceptor::from(Arc::new(config)); + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = tls_stream.next().await; + debug!("Server received connection: {:?}", result.is_some()); + result + }); + + // Connect to the server with a TLS client that doesn't trust the server's certificate + let connector = tokio_rustls::TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + )); + + debug!("Client connecting to {}", server_addr); + let tcp_stream = TcpStream::connect(server_addr).await?; + let domain = ServerName::try_from("localhost")?; + + // This connection should fail due to certificate verification + let client_result = connector.connect(domain, tcp_stream).await; + match &client_result { + Ok(_) => warn!("Client connection succeeded unexpectedly"), + Err(e) => info!("Client connection failed as expected: {}", e), + } + assert!(client_result.is_err()); + + // The server should not encounter an error, but the connection should not be established + let server_result = server_task + .await? + .ok_or("Server task completed without result")?; + match &server_result { + Ok(_) => warn!("Server accepted connection unexpectedly"), + Err(e) => info!("Server did not establish connection as expected: {}", e), + } + assert!(server_result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_tls_incoming_client_hello_timeout() -> Result<(), Box> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_tls_incoming_client_hello_timeout"); + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + debug!("Server listening on {}", server_addr); + let incoming = TcpListenerStream::new(listener); + + let tls_acceptor = create_test_tls_acceptor().await?; + + // Use serve_tcp_incoming to handle TCP connections + let tcp_incoming = serve_tcp_incoming(incoming); + + // Spawn the server task + let server_task = tokio::spawn(async move { + debug!("Server task started"); + let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let result = + tokio::time::timeout(std::time::Duration::from_millis(10), tls_stream.next()).await; + debug!("Server task completed with result: {:?}", result.is_err()); + result + }); + + // Connect with a regular TCP client (no TLS handshake) + debug!("Client connecting with plain TCP to {}", server_addr); + let _tcp_stream = TcpStream::connect(server_addr).await?; + debug!("Client connected with plain TCP"); + + // The server task should timeout + let result = server_task.await?; + match result { + Ok(_) => warn!("Server did not timeout as expected"), + Err(ref e) => info!("Server timed out as expected: {}", e), + } + assert!(result.is_err()); // Timeout error + + Ok(()) + } + + #[tokio::test] + async fn test_load_certs() -> io::Result<()> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_load_certs"); + let certs = load_certs("examples/sample.pem")?; + debug!("Loaded {} certificates", certs.len()); + assert!(!certs.is_empty(), "Certificate file should not be empty"); + Ok(()) + } + + #[tokio::test] + async fn test_load_private_key() -> io::Result<()> { + let _guard = tracing_subscriber::fmt() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + info!("Starting test_load_private_key"); + let key = load_private_key("examples/sample.rsa")?; + debug!("Loaded private key, length: {}", key.secret_der().len()); + assert!( + !key.secret_der().is_empty(), + "Private key should not be empty" + ); + Ok(()) + } + + // Simulating the tls_incoming function for testing purposes + // Replace this with your actual implementation + fn tls_incoming( + incoming: impl Stream> + Send + 'static, + tls_acceptor: TlsAcceptor, + ) -> impl Stream, Error>> + Send + 'static + where + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + Box::pin(incoming.then(move |result| { + let tls_acceptor = tls_acceptor.clone(); + async move { + match result { + Ok(io) => { + debug!("Accepting TLS connection"); + let accept_result = tls_acceptor.accept(io).await.map_err(Error::from); + match &accept_result { + Ok(_) => debug!("TLS connection accepted successfully"), + Err(e) => warn!("Failed to accept TLS connection: {}", e), + } + accept_result + } + Err(e) => { + warn!("Error in incoming connection: {}", e); + Err(e) + } + } + } + })) + } +} diff --git a/src/tls_openssl/future.rs b/src/tls_openssl/future.rs deleted file mode 100644 index 54f64b6..0000000 --- a/src/tls_openssl/future.rs +++ /dev/null @@ -1,183 +0,0 @@ -//! Future types. -//! -//! This module provides the futures and supporting types for integrating OpenSSL with a hyper/tokio HTTP/TLS server. -//! `OpenSSLAcceptorFuture` is the main public-facing type which wraps around the core logic of establishing an SSL/TLS -//! connection. - -use super::OpenSSLConfig; -use pin_project_lite::pin_project; -use std::io::{Error, ErrorKind}; -use std::time::Duration; -use std::{ - fmt, - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::{timeout, Timeout}; - -use openssl::ssl::Ssl; -use tokio_openssl::SslStream; - -// The OpenSSLAcceptorFuture encapsulates the asynchronous logic of accepting an SSL/TLS connection. -pin_project! { - /// A Future for establishing an SSL/TLS connection using `OpenSSLAcceptor`. - /// - /// This wraps around the process of asynchronously establishing an SSL/TLS connection via OpenSSL. - /// It waits for the inner non-TLS connection to be established, and then handles the TLS handshake. - pub struct OpenSSLAcceptorFuture { - #[pin] - inner: AcceptFuture, // Inner future which manages the state machine of accepting connections. - config: Option, // The SSL/TLS configuration to use for the handshake. - } -} - -impl OpenSSLAcceptorFuture { - /// Constructs a new `OpenSSLAcceptorFuture`. - /// - /// # Arguments - /// - `future`: The initial future that handles the non-TLS accept phase. - /// - `config`: SSL/TLS configuration. - /// - `handshake_timeout`: Maximum duration allowed for the TLS handshake. - pub(crate) fn new(future: F, config: OpenSSLConfig, handshake_timeout: Duration) -> Self { - let inner = AcceptFuture::InnerAccepting { - future, - handshake_timeout, - }; - let config = Some(config); - - Self { inner, config } - } -} - -impl fmt::Debug for OpenSSLAcceptorFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLAcceptorFuture").finish() - } -} - -// A future for performing the SSL/TLS handshake using an `SslStream`. -pin_project! { - struct TlsAccept { - #[pin] - tls_stream: Option>, // The SSL/TLS stream on which the handshake will be performed. - } -} - -impl Future for TlsAccept -where - I: AsyncRead + AsyncWrite + Unpin, // The inner type must support asynchronous reading and writing. -{ - type Output = io::Result>; // The result will be an `SslStream` if the handshake is successful. - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - match this - .tls_stream - .as_mut() - .as_pin_mut() - .map(|inner| inner.poll_accept(cx)) - .expect("tlsaccept polled after ready") - { - Poll::Ready(Ok(())) => { - let tls_stream = this.tls_stream.take().expect("tls stream vanished?"); - Poll::Ready(Ok(tls_stream)) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::new(ErrorKind::Other, e))), - Poll::Pending => Poll::Pending, - } - } -} - -// Enumerates the possible states of the accept process, either waiting for the inner non-TLS -// connection to be accepted, or performing the TLS handshake. -pin_project! { - #[project = AcceptFutureProj] - enum AcceptFuture { - // Waiting for the non-TLS connection to be accepted. - InnerAccepting { - #[pin] - future: F, // The future representing the non-TLS accept phase. - handshake_timeout: Duration, // Maximum duration for the TLS handshake. - }, - // Performing the TLS handshake. - TlsAccepting { - #[pin] - future: Timeout>, // Future that represents the TLS handshake, with a timeout. - service: Option, // The underlying service that will handle the request after the TLS handshake. - } - } -} - -// Main implementation of the future for `OpenSSLAcceptor`. -impl Future for OpenSSLAcceptorFuture -where - F: Future>, // The initial non-TLS accept future. - I: AsyncRead + AsyncWrite + Unpin, // The inner type must support asynchronous reading and writing. -{ - type Output = io::Result<(SslStream, S)>; // The output will be an `SslStream` and the service to handle the request. - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - // The inner future here is what is doing the lower level accept, such as - // our tcp socket. - // - // So we poll on that first, when it's ready we then swap our the inner future to - // one waiting for our ssl layer to accept/install. - // - // Then once that's ready we can then wrap and provide the SslStream back out. - - // This loop exists to allow the Poll::Ready from InnerAccept on complete - // to re-poll immediately. Otherwise all other paths are immediate returns. - loop { - match this.inner.as_mut().project() { - AcceptFutureProj::InnerAccepting { - future, - handshake_timeout, - } => match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let server_config = this.config.take().expect( - "config is not set. this is a bug in hyper-server, please report", - ); - - // Change to poll::ready(err) - let ssl = Ssl::new(server_config.acceptor.context()).unwrap(); - - let tls_stream = SslStream::new(ssl, stream).unwrap(); - let future = TlsAccept { - tls_stream: Some(tls_stream), - }; - - let service = Some(service); - let handshake_timeout = *handshake_timeout; - - this.inner.set(AcceptFuture::TlsAccepting { - future: timeout(handshake_timeout, future), - service, - }); - // the loop is now triggered to immediately poll on - // ssl stream accept. - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - }, - - AcceptFutureProj::TlsAccepting { future, service } => match future.poll(cx) { - Poll::Ready(Ok(Ok(stream))) => { - let service = service.take().expect("future polled after ready"); - return Poll::Ready(Ok((stream, service))); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - } - } - } -} diff --git a/src/tls_openssl/mod.rs b/src/tls_openssl/mod.rs deleted file mode 100644 index f108517..0000000 --- a/src/tls_openssl/mod.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Tls implementation using [`openssl`] -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use hyper_server::tls_openssl::OpenSSLConfig; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let config = OpenSSLConfig::from_pem_file( -//! "examples/self-signed-certs/cert.pem", -//! "examples/self-signed-certs/key.pem", -//! ) -//! .unwrap(); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind_openssl(addr, config) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` - -use self::future::OpenSSLAcceptorFuture; -use crate::{ - accept::{Accept, DefaultAcceptor}, - server::Server, -}; -use openssl::ssl::{ - Error as OpenSSLError, SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod, -}; -use std::{convert::TryFrom, fmt, net::SocketAddr, path::Path, sync::Arc, time::Duration}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_openssl::SslStream; - -pub mod future; - -/// Binds a TLS server using OpenSSL to the specified address with the given configuration. -/// -/// The server is configured to accept TLS encrypted connections. -/// -/// # Arguments -/// -/// * `addr`: The address to which the server will bind. -/// * `config`: The TLS configuration for the server. -/// -/// # Returns -/// -/// A configured `Server` instance ready to be run. -pub fn bind_openssl(addr: SocketAddr, config: OpenSSLConfig) -> Server { - let acceptor = OpenSSLAcceptor::new(config); - Server::bind(addr).acceptor(acceptor) -} - -/// Represents a TLS acceptor that uses OpenSSL for cryptographic operations. -/// -/// This structure is used for handling TLS encrypted connections. -/// -/// The acceptor is backed by OpenSSL, and is used to upgrade incoming non-secure connections -/// to secure TLS connections. -/// -/// The default TLS handshake timeout is set to 10 seconds. -#[derive(Clone)] -pub struct OpenSSLAcceptor { - inner: A, - config: OpenSSLConfig, - handshake_timeout: Duration, -} - -impl OpenSSLAcceptor { - /// Constructs a new instance of the OpenSSL acceptor. - /// - /// # Arguments - /// - /// * `config`: Configuration options for the OpenSSL server. - pub fn new(config: OpenSSLConfig) -> Self { - let inner = DefaultAcceptor::new(); - - // Default handshake timeout is 10 seconds. - #[cfg(not(test))] - let handshake_timeout = Duration::from_secs(10); - - // For tests, use a shorter timeout to avoid unnecessary delays. - #[cfg(test)] - let handshake_timeout = Duration::from_secs(1); - - Self { - inner, - config, - handshake_timeout, - } - } - - /// Overrides the default TLS handshake timeout. - /// - /// # Arguments - /// - /// * `val`: The duration to set as the new handshake timeout. - /// - /// # Returns - /// - /// A modified version of the current acceptor with the new timeout value. - pub fn handshake_timeout(mut self, val: Duration) -> Self { - self.handshake_timeout = val; - self - } -} - -impl Accept for OpenSSLAcceptor -where - A: Accept, - A::Stream: AsyncRead + AsyncWrite + Unpin, -{ - type Stream = SslStream; - type Service = A::Service; - type Future = OpenSSLAcceptorFuture; - - /// Handles the incoming stream, initiates a TLS handshake, and upgrades it to a secure connection. - fn accept(&self, stream: I, service: S) -> Self::Future { - let inner_future = self.inner.accept(stream, service); - let config = self.config.clone(); - - OpenSSLAcceptorFuture::new(inner_future, config, self.handshake_timeout) - } -} - -impl fmt::Debug for OpenSSLAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLAcceptor").finish() - } -} - -/// Represents configuration options for an OpenSSL-based server. -/// -/// This configuration is used when constructing a new `OpenSSLAcceptor`. -#[derive(Clone)] -pub struct OpenSSLConfig { - acceptor: Arc, -} - -impl OpenSSLConfig { - /// Creates a new configuration using a PEM formatted certificate and key. - /// - /// # Arguments - /// - /// * `cert`: Path to the PEM-formatted certificate file. - /// * `key`: Path to the PEM-formatted private key file. - /// - /// # Returns - /// - /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`. - pub fn from_pem_file, B: AsRef>( - cert: A, - key: B, - ) -> Result { - let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?; - - tls_builder.set_certificate_file(cert, SslFiletype::PEM)?; - tls_builder.set_private_key_file(key, SslFiletype::PEM)?; - tls_builder.check_private_key()?; - - let acceptor = Arc::new(tls_builder.build()); - - Ok(OpenSSLConfig { acceptor }) - } - - /// Creates a new configuration using a PEM formatted certificate chain and key. - /// - /// # Arguments - /// - /// * `chain`: Path to the PEM-formatted certificate chain file. - /// * `key`: Path to the PEM-formatted private key file. - /// - /// # Returns - /// - /// A `Result` that contains an `OpenSSLConfig` or an `OpenSSLError`. - pub fn from_pem_chain_file, B: AsRef>( - chain: A, - key: B, - ) -> Result { - let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls())?; - - tls_builder.set_certificate_chain_file(chain)?; - tls_builder.set_private_key_file(key, SslFiletype::PEM)?; - tls_builder.check_private_key()?; - - let acceptor = Arc::new(tls_builder.build()); - - Ok(OpenSSLConfig { acceptor }) - } -} - -impl TryFrom for OpenSSLConfig { - type Error = OpenSSLError; - - /// Constructs [`OpenSSLConfig`] from an [`SslAcceptorBuilder`]. This allows precise - /// control over the settings that will be used by OpenSSL in this server. - /// - /// # Example - /// ``` - /// use hyper_server::tls_openssl::OpenSSLConfig; - /// use openssl::ssl::{SslAcceptor, SslMethod}; - /// use std::convert::TryFrom; - /// - /// #[tokio::main] - /// async fn main() { - /// let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()) - /// .unwrap(); - /// // Set configurations like set_certificate_chain_file or - /// // set_private_key_file. - /// // let tls_builder.set_ ... ; - - /// let _config = OpenSSLConfig::try_from(tls_builder); - /// } - /// ``` - fn try_from(tls_builder: SslAcceptorBuilder) -> Result { - tls_builder.check_private_key()?; - let acceptor = Arc::new(tls_builder.build()); - Ok(OpenSSLConfig { acceptor }) - } -} - -impl fmt::Debug for OpenSSLConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenSSLConfig").finish() - } -} - -#[cfg(test)] -pub(crate) mod tests { - use crate::{ - handle::Handle, - tls_openssl::{self, OpenSSLConfig}, - }; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use std::{io, net::SocketAddr, time::Duration}; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tower::{Service, ServiceExt}; - - use openssl::ssl::{Ssl, SslConnector, SslMethod, SslVerifyMode}; - use std::pin::Pin; - use tokio_openssl::SslStream; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Don't disconnect client. - // conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = OpenSSLConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .unwrap(); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_openssl::bind_openssl(addr, config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector(dns_name(), stream).await; - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream { - let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap(); - tls_parms.set_verify(SslVerifyMode::NONE); - let hostname_owned = hostname.to_string(); - tls_parms.set_client_hello_callback(move |ssl_ref, _ssl_alert| { - ssl_ref - .set_hostname(hostname_owned.as_str()) - .map(|()| openssl::ssl::ClientHelloResponse::SUCCESS) - }); - let tls_parms = tls_parms.build(); - - let ssl = Ssl::new(tls_parms.context()).unwrap(); - let mut tls_stream = SslStream::new(ssl, stream).unwrap(); - - SslStream::connect(Pin::new(&mut tls_stream)).await.unwrap(); - - tls_stream - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn dns_name() -> &'static str { - "localhost" - } -} diff --git a/src/tls_rustls/future.rs b/src/tls_rustls/future.rs deleted file mode 100644 index 843046f..0000000 --- a/src/tls_rustls/future.rs +++ /dev/null @@ -1,130 +0,0 @@ -//! Module containing futures specific to the `rustls` TLS acceptor for the server. -//! -//! This module primarily provides the `RustlsAcceptorFuture` which is responsible for performing the TLS handshake -//! using the `rustls` library. - -use crate::tls_rustls::RustlsConfig; -use pin_project_lite::pin_project; -use std::io::{Error, ErrorKind}; -use std::time::Duration; -use std::{ - fmt, - future::Future, - io, - pin::Pin, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::time::{timeout, Timeout}; -use tokio_rustls::{server::TlsStream, Accept, TlsAcceptor}; - -pin_project! { - /// A future representing the asynchronous TLS handshake using the `rustls` library. - /// - /// Once completed, it yields a `TlsStream` which is a wrapper around the actual underlying stream, with - /// encryption and decryption operations applied to it. - pub struct RustlsAcceptorFuture { - #[pin] - inner: AcceptFuture, - config: Option, - } -} - -impl RustlsAcceptorFuture { - /// Constructs a new `RustlsAcceptorFuture`. - /// - /// * `future`: The future that resolves to the original non-encrypted stream. - /// * `config`: The rustls configuration to use for the handshake. - /// * `handshake_timeout`: The maximum duration to wait for the handshake to complete. - pub(crate) fn new(future: F, config: RustlsConfig, handshake_timeout: Duration) -> Self { - let inner = AcceptFuture::Inner { - future, - handshake_timeout, - }; - let config = Some(config); - - Self { inner, config } - } -} - -impl fmt::Debug for RustlsAcceptorFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsAcceptorFuture").finish() - } -} - -pin_project! { - /// Internal states of the handshake process. - #[project = AcceptFutureProj] - enum AcceptFuture { - /// Initial state where we have a future that resolves to the original non-encrypted stream. - Inner { - #[pin] - future: F, - handshake_timeout: Duration, - }, - /// State after receiving the stream where the handshake is performed asynchronously. - Accept { - #[pin] - future: Timeout>, - service: Option, - }, - } -} - -impl Future for RustlsAcceptorFuture -where - F: Future>, - I: AsyncRead + AsyncWrite + Unpin, -{ - type Output = io::Result<(TlsStream, S)>; - - /// Advances the handshake state machine. - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - - loop { - match this.inner.as_mut().project() { - AcceptFutureProj::Inner { - future, - handshake_timeout, - } => { - // Poll the future to get the original stream. - match future.poll(cx) { - Poll::Ready(Ok((stream, service))) => { - let server_config = this.config - .take() - .expect("config is not set. this is a bug in hyper-server, please report") - .get_inner(); - - let acceptor = TlsAcceptor::from(server_config); - let future = acceptor.accept(stream); - - let service = Some(service); - let handshake_timeout = *handshake_timeout; - - this.inner.set(AcceptFuture::Accept { - future: timeout(handshake_timeout, future), - service, - }); - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } - } - AcceptFutureProj::Accept { future, service } => match future.poll(cx) { - Poll::Ready(Ok(Ok(stream))) => { - let service = service.take().expect("future polled after ready"); - - return Poll::Ready(Ok((stream, service))); - } - Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), - Poll::Ready(Err(timeout)) => { - return Poll::Ready(Err(Error::new(ErrorKind::TimedOut, timeout))) - } - Poll::Pending => return Poll::Pending, - }, - } - } - } -} diff --git a/src/tls_rustls/mod.rs b/src/tls_rustls/mod.rs deleted file mode 100644 index af7e391..0000000 --- a/src/tls_rustls/mod.rs +++ /dev/null @@ -1,579 +0,0 @@ -//! Tls implementation using [`rustls`]. -//! -//! # Example -//! -//! ```rust,no_run -//! use axum::{routing::get, Router}; -//! use hyper_server::tls_rustls::RustlsConfig; -//! use std::net::SocketAddr; -//! -//! #[tokio::main] -//! async fn main() { -//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); -//! -//! let config = RustlsConfig::from_pem_file( -//! "examples/self-signed-certs/cert.pem", -//! "examples/self-signed-certs/key.pem", -//! ) -//! .await -//! .unwrap(); -//! -//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -//! println!("listening on {}", addr); -//! hyper_server::bind_rustls(addr, config) -//! .serve(app.into_make_service()) -//! .await -//! .unwrap(); -//! } -//! ``` - -use self::future::RustlsAcceptorFuture; -use crate::{ - accept::{Accept, DefaultAcceptor}, - server::{io_other, Server}, -}; -use arc_swap::ArcSwap; -use rustls::{Certificate, PrivateKey, ServerConfig}; -use std::time::Duration; -use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - task::spawn_blocking, -}; -use tokio_rustls::server::TlsStream; - -/// Sub-module that contains re-exported public interfaces. -pub(crate) mod export { - use super::{RustlsAcceptor, RustlsConfig, Server, SocketAddr}; - - /// Creates a TLS server that binds to the provided address using the rustls library. - #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] - pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server { - super::bind_rustls(addr, config) - } - - /// Creates a TLS server from an existing `std::net::TcpListener` using the rustls library. - #[cfg_attr(docsrs, doc(cfg(feature = "tls-rustls")))] - pub fn from_tcp_rustls( - listener: std::net::TcpListener, - config: RustlsConfig, - ) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::from_tcp(listener).acceptor(acceptor) - } -} - -pub mod future; - -/// Helper function to create a TLS server bound to a provided address. -pub fn bind_rustls(addr: SocketAddr, config: RustlsConfig) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::bind(addr).acceptor(acceptor) -} - -/// Helper function to create a TLS server from an existing `std::net::TcpListener`. -pub fn from_tcp_rustls( - listener: std::net::TcpListener, - config: RustlsConfig, -) -> Server { - let acceptor = RustlsAcceptor::new(config); - - Server::from_tcp(listener).acceptor(acceptor) -} - -/// A TLS acceptor implementation using the rustls library. -#[derive(Clone)] -pub struct RustlsAcceptor { - inner: A, - config: RustlsConfig, - handshake_timeout: Duration, -} - -impl RustlsAcceptor { - /// Constructs a new rustls acceptor with the given configuration. - pub fn new(config: RustlsConfig) -> Self { - let inner = DefaultAcceptor::new(); - - // Default handshake timeout is set to 10 seconds. - // In test mode, this is reduced to 1 second to avoid waiting too long. - #[cfg(not(test))] - let handshake_timeout = Duration::from_secs(10); - #[cfg(test)] - let handshake_timeout = Duration::from_secs(1); - - Self { - inner, - config, - handshake_timeout, - } - } - - /// Allows overriding the default TLS handshake timeout. - pub fn handshake_timeout(mut self, val: Duration) -> Self { - self.handshake_timeout = val; - self - } -} - -impl RustlsAcceptor { - /// Replaces the inner acceptor with a custom acceptor. - pub fn acceptor(self, acceptor: Acceptor) -> RustlsAcceptor { - RustlsAcceptor { - inner: acceptor, - config: self.config, - handshake_timeout: self.handshake_timeout, - } - } -} - -// Implementation to accept incoming TLS connections using rustls. -impl Accept for RustlsAcceptor -where - A: Accept, - A::Stream: AsyncRead + AsyncWrite + Unpin, -{ - type Stream = TlsStream; - type Service = A::Service; - type Future = RustlsAcceptorFuture; - - fn accept(&self, stream: I, service: S) -> Self::Future { - let inner_future = self.inner.accept(stream, service); - let config = self.config.clone(); - - RustlsAcceptorFuture::new(inner_future, config, self.handshake_timeout) - } -} - -impl fmt::Debug for RustlsAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsAcceptor").finish() - } -} - -/// Represents the rustls configuration for the server. -#[derive(Clone)] -pub struct RustlsConfig { - inner: Arc>, -} - -// The `RustlsConfig` structure represents configuration data for rustls. -impl RustlsConfig { - /// Create a new `RustlsConfig` from an `Arc`. - /// - /// Important: This method does not set ALPN protocols (like `http/1.1` or `h2`) automatically. - /// ALPN protocols need to be set manually when using this method. - pub fn from_config(config: Arc) -> Self { - let inner = Arc::new(ArcSwap::new(config)); - Self { inner } - } - - /// Create a `RustlsConfig` from DER-encoded data. - /// DER is a binary format for encoding data, commonly used for certificates and keys. - /// - /// `cert` is expected to be a DER-encoded X.509 certificate. - /// `key` is expected to be a DER-encoded ASN.1 format private key, either in PKCS#8 or PKCS#1 format. - pub async fn from_der(cert: Vec>, key: Vec) -> io::Result { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Create a `RustlsConfig` from PEM-formatted data. - /// PEM is a text-based format used to encode binary data like certificates and keys. - /// - /// Both `cert` and `key` must be provided in PEM format. - pub async fn from_pem(cert: Vec, key: Vec) -> io::Result { - let server_config = spawn_blocking(|| config_from_pem(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Create a `RustlsConfig` by reading PEM-formatted files. - /// - /// The contents of the provided certificate and private key files must be in PEM format. - pub async fn from_pem_file(cert: impl AsRef, key: impl AsRef) -> io::Result { - let server_config = config_from_pem_file(cert, key).await?; - let inner = Arc::new(ArcSwap::from_pointee(server_config)); - Ok(Self { inner }) - } - - /// Retrieve the inner `Arc` from the `RustlsConfig`. - pub fn get_inner(&self) -> Arc { - self.inner.load_full() - } - - /// Update (or reload) the `RustlsConfig` with a new `Arc`. - pub fn reload_from_config(&self, config: Arc) { - self.inner.store(config); - } - - /// Reload the `RustlsConfig` from provided DER-encoded data. - /// - /// As with the `from_der` method, `cert` must be DER-encoded X.509 and `key` - /// should be in either PKCS#8 or PKCS#1 DER-encoded ASN.1 format. - pub async fn reload_from_der(&self, cert: Vec>, key: Vec) -> io::Result<()> { - let server_config = spawn_blocking(|| config_from_der(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } - - /// Reload the `RustlsConfig` using provided PEM-formatted data. - pub async fn reload_from_pem(&self, cert: Vec, key: Vec) -> io::Result<()> { - let server_config = spawn_blocking(|| config_from_pem(cert, key)) - .await - .unwrap()?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } - - /// Reload the `RustlsConfig` from provided PEM-formatted files. - pub async fn reload_from_pem_file( - &self, - cert: impl AsRef, - key: impl AsRef, - ) -> io::Result<()> { - let server_config = config_from_pem_file(cert, key).await?; - let inner = Arc::new(server_config); - self.inner.store(inner); - Ok(()) - } -} - -// This provides a debug representation for the `RustlsConfig`. -impl fmt::Debug for RustlsConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RustlsConfig").finish() - } -} - -// Helper function to convert DER-encoded certificate and key into rustls's `ServerConfig`. -fn config_from_der(cert: Vec>, key: Vec) -> io::Result { - // Convert the raw bytes into rustls's Certificate and PrivateKey structures. - let cert = cert.into_iter().map(Certificate).collect(); - let key = PrivateKey(key); - - // Construct the ServerConfig. - let mut config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(cert, key) - .map_err(io_other)?; - - // Set ALPN protocols for the configuration. - config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - Ok(config) -} - -// Helper function to convert PEM-formatted certificate and key into rustls' `ServerConfig`. -fn config_from_pem(cert: Vec, key: Vec) -> io::Result { - use rustls_pemfile::Item; - - // Parse PEM formatted data into rustls structures. - let cert = rustls_pemfile::certs(&mut cert.as_ref())?; - let key = match rustls_pemfile::read_one(&mut key.as_ref())? { - Some(Item::RSAKey(key)) | Some(Item::PKCS8Key(key)) | Some(Item::ECKey(key)) => key, - _ => return Err(io_other("private key format not supported")), - }; - - config_from_der(cert, key) -} - -// Helper function to read PEM-formatted files and convert them into rustls' ServerConfig. -async fn config_from_pem_file( - cert: impl AsRef, - key: impl AsRef, -) -> io::Result { - // Read the PEM files asynchronously. - let cert = tokio::fs::read(cert.as_ref()).await?; - let key = tokio::fs::read(key.as_ref()).await?; - - config_from_pem(cert, key) -} - -#[cfg(test)] -pub(crate) mod tests { - use crate::{ - handle::Handle, - tls_rustls::{self, RustlsConfig}, - }; - use axum::{routing::get, Router}; - use bytes::Bytes; - use http::{response, Request}; - use hyper::{ - client::conn::{handshake, SendRequest}, - Body, - }; - use rustls::{ - client::{ServerCertVerified, ServerCertVerifier}, - Certificate, ClientConfig, ServerName, - }; - use std::{ - convert::TryFrom, - io, - net::SocketAddr, - sync::Arc, - time::{Duration, SystemTime}, - }; - use tokio::time::sleep; - use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; - use tokio_rustls::TlsConnector; - use tower::{Service, ServiceExt}; - - #[tokio::test] - async fn start_and_request() { - let (_handle, _server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - } - - #[tokio::test] - async fn tls_timeout() { - let (handle, _server_task, addr) = start_server().await; - assert_eq!(handle.connection_count(), 0); - - // We intentionally avoid driving a TLS handshake to completion. - let _stream = TcpStream::connect(addr).await.unwrap(); - - sleep(Duration::from_millis(500)).await; - assert_eq!(handle.connection_count(), 1); - - tokio::time::sleep(Duration::from_millis(1000)).await; - // Timeout defaults to 1s during testing, and we have waited 1.5 seconds. - assert_eq!(handle.connection_count(), 0); - } - - #[tokio::test] - async fn test_reload() { - let handle = Handle::new(); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - let server_handle = handle.clone(); - let rustls_config = config.clone(); - tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, rustls_config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - let cert_a = get_first_cert(addr).await; - let mut cert_b = get_first_cert(addr).await; - - assert_eq!(cert_a, cert_b); - - config - .reload_from_pem_file( - "examples/self-signed-certs/reload/cert.pem", - "examples/self-signed-certs/reload/key.pem", - ) - .await - .unwrap(); - - cert_b = get_first_cert(addr).await; - - assert_ne!(cert_a, cert_b); - - config - .reload_from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await - .unwrap(); - - cert_b = get_first_cert(addr).await; - - assert_eq!(cert_a, cert_b); - } - - #[tokio::test] - async fn test_shutdown() { - let (handle, _server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.shutdown(); - - let response_future_result = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await; - - assert!(response_future_result.is_err()); - - // Connection task should finish soon. - let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); - } - - #[tokio::test] - async fn test_graceful_shutdown() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, conn) = connect(addr).await; - - handle.graceful_shutdown(None); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Disconnect client. - conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - #[tokio::test] - async fn test_graceful_shutdown_timed() { - let (handle, server_task, addr) = start_server().await; - - let (mut client, _conn) = connect(addr).await; - - handle.graceful_shutdown(Some(Duration::from_millis(250))); - - let (_parts, body) = send_empty_request(&mut client).await; - - assert_eq!(body.as_ref(), b"Hello, world!"); - - // Don't disconnect client. - // conn.abort(); - - // Server task should finish soon. - let server_result = timeout(Duration::from_secs(1), server_task) - .await - .unwrap() - .unwrap(); - - assert!(server_result.is_ok()); - } - - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { - let handle = Handle::new(); - - let server_handle = handle.clone(); - let server_task = tokio::spawn(async move { - let app = Router::new().route("/", get(|| async { "Hello, world!" })); - - let config = RustlsConfig::from_pem_file( - "examples/self-signed-certs/cert.pem", - "examples/self-signed-certs/key.pem", - ) - .await?; - - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - tls_rustls::bind_rustls(addr, config) - .handle(server_handle) - .serve(app.into_make_service()) - .await - }); - - let addr = handle.listening().await.unwrap(); - - (handle, server_task, addr) - } - - async fn get_first_cert(addr: SocketAddr) -> Certificate { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap(); - - let (_io, client_connection) = tls_stream.into_inner(); - - client_connection.peer_certificates().unwrap()[0].clone() - } - - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { - let stream = TcpStream::connect(addr).await.unwrap(); - let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap(); - - let (send_request, connection) = handshake(tls_stream).await.unwrap(); - - let task = tokio::spawn(async move { - let _ = connection.await; - }); - - (send_request, task) - } - - async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { - let (parts, body) = client - .ready() - .await - .unwrap() - .call(Request::new(Body::empty())) - .await - .unwrap() - .into_parts(); - let body = hyper::body::to_bytes(body).await.unwrap(); - - (parts, body) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn tls_connector() -> TlsConnector { - struct NoVerify; - - impl ServerCertVerifier for NoVerify { - fn verify_server_cert( - &self, - _end_entity: &Certificate, - _intermediates: &[Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: SystemTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - } - - let mut client_config = ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(Arc::new(NoVerify)) - .with_no_client_auth(); - - client_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - TlsConnector::from(Arc::new(client_config)) - } - - /// Used in `proxy-protocol` feature tests. - pub(crate) fn dns_name() -> ServerName { - ServerName::try_from("localhost").unwrap() - } -}