diff --git a/Cargo.toml b/Cargo.toml index d2da35f..5422092 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,10 +26,10 @@ hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", pin-project = "1.1.5" pprof = { version = "0.13.0", features = ["flamegraph"], optional = true } ring = "0.17.8" -rustls = { version = "0.23.13", features = ["zlib"] } +rustls = { version = "0.23.13", features = ["zlib", "aws_lc_rs"] } rustls-pemfile = "2.1.3" tokio = { version = "1.40.0", features = ["net", "macros", "rt-multi-thread", "time"] } -tokio-rustls = "0.26.0" +tokio-rustls = { version = "0.26.0", features = ["aws_lc_rs"] } tokio-stream = { version = "0.1.16", features = ["net"] } tokio-util = "0.7.12" tower = { version = "0.5.1", features = ["util"] } diff --git a/src/http.rs b/src/http.rs index b497917..09bf19f 100644 --- a/src/http.rs +++ b/src/http.rs @@ -570,6 +570,16 @@ mod tests { // Utility functions + fn init_crypto_provider() { + // This and some other helper functions need a bit of DRY + match rustls::crypto::aws_lc_rs::default_provider().install_default() { + Ok(_) => debug!("Default crypto provider installed successfully"), + Err(_) => { + // Crypto provider is already installed + } + } + } + async fn echo(req: Request) -> Result>, hyper::Error> { match (req.method(), req.uri().path()) { (&hyper::Method::GET, "/") => { @@ -812,6 +822,7 @@ mod tests { #[tokio::test] async fn test_https_connection() { + init_crypto_provider(); let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let (incoming, server_addr) = setup_test_server(addr).await; @@ -865,6 +876,7 @@ mod tests { #[tokio::test] async fn test_https_invalid_client_cert() { + init_crypto_provider(); let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let (incoming, server_addr) = setup_test_server(addr).await; @@ -905,6 +917,7 @@ mod tests { } #[tokio::test] async fn test_https_graceful_shutdown() { + init_crypto_provider(); let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let (incoming, server_addr) = setup_test_server(addr).await; diff --git a/src/tcp.rs b/src/tcp.rs index 1d26d4f..8934ad2 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -1,5 +1,5 @@ use crate::error::handle_accept_error; -use crate::Error; +use crate::Error as TransportError; use std::ops::ControlFlow; use std::pin::pin; use tokio::io::{AsyncRead, AsyncWrite}; @@ -55,10 +55,10 @@ use tokio_stream::{Stream, StreamExt}; #[inline] pub fn serve_tcp_incoming( incoming: impl Stream> + Send + 'static, -) -> impl Stream> +) -> impl Stream> where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - IE: Into + Send + 'static, + IE: Into + Send + 'static, { async_stream::stream! { // We pin the stream on the stack to ensure that it's safe to diff --git a/src/tls.rs b/src/tls.rs index 6f55978..0a435a2 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,63 +1,124 @@ -use crate::Error; +use crate::error::handle_accept_error; +use crate::Error as TransportError; +use futures::stream::StreamExt; use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use std::ops::ControlFlow; +use std::pin::pin; use std::{fs, io}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; -use tokio_stream::{Stream, StreamExt}; +use tokio_stream::Stream; /// Creates a stream of TLS-encrypted connections from a stream of TCP connections. /// /// This function takes a stream of TCP connections and a TLS acceptor, and produces -/// a new stream that yields TLS-encrypted connections. It handles both the successful -/// case of establishing a TLS connection and the error cases. +/// a new stream that yields TLS-encrypted connections. It handles both successful +/// TLS handshakes and various error scenarios, providing a robust way to upgrade +/// TCP connections to TLS. /// /// # Type Parameters /// -/// * `IO`: The I/O type representing the underlying TCP connection. It must implement -/// `AsyncRead`, `AsyncWrite`, `Unpin`, `Send`, and have a static lifetime. +/// * [`IO`]: The I/O type representing the underlying TCP connection. It must implement +/// [`AsyncRead`], [`AsyncWrite`], [`Unpin`], [`Send`], and have a static lifetime. +/// * [`IE`]: The error type of the incoming TCP stream, which must be convertible to +/// the crate's [`TransportError`]. /// /// # Arguments /// -/// * `tcp_stream`: A stream that yields `Result` items, representing incoming +/// * `tcp_stream`: A stream that yields `Result` items, representing incoming /// TCP connections or errors. -/// * `tls`: A `TlsAcceptor` used to perform the TLS handshake on each TCP connection. +/// * `tls`: A [`TlsAcceptor`] used to perform the TLS handshake on each TCP connection. /// /// # Returns /// -/// A new `Stream` that yields `Result, Error>` items. -/// Each item is either a successfully established TLS connection or an error. +/// A new [`Stream`] that yields `Result, TransportError>` +/// items. Each item is either a successfully established TLS connection or an error. /// /// # Error Handling /// -/// - If the input `tcp_stream` yields an error, that error is propagated. -/// - If the TLS handshake fails, the error is wrapped in the crate's `Error` type. +/// - TCP connection errors from the input stream are passed through the `handle_accept_error` function. +/// - TLS handshake errors are converted to [`TransportError`] and passed through [`handle_accept_error`] +/// - Non-fatal errors result in skipping the current connection attempt and continuing to the next. +/// - Fatal errors are propagated, potentially leading to stream termination. +/// +/// # Examples +/// +/// ```rust,no_run +/// use std::net::SocketAddr; +/// use tokio_stream::wrappers::TcpListenerStream; +/// use tokio::net::TcpListener; +/// use tokio_rustls::TlsAcceptor; +/// use std::sync::Arc;/// +/// use hyper_server::{serve_tcp_incoming, serve_tls_incoming}; +/// +/// async fn run_tls_server(tls_config: Arc) { +/// let addr = SocketAddr::from(([127, 0, 0, 1], 8443)); +/// let listener = TcpListener::bind("127.0.0.1:443").await.unwrap(); +/// let tcp_stream = TcpListenerStream::new(listener); +/// let tls_acceptor = TlsAcceptor::from(tls_config); +/// +/// let tcp_incoming = serve_tcp_incoming(tcp_stream); +/// let tls_stream = serve_tls_incoming(tcp_incoming, tls_acceptor); +/// +/// // Use the tls_stream for further processing... +/// } +/// ``` #[inline] -pub fn serve_tls_incoming( - tcp_stream: impl Stream>, +pub fn serve_tls_incoming( + tcp_stream: impl Stream> + Send + 'static, tls: TlsAcceptor, -) -> impl Stream, Error>> +) -> impl Stream, TransportError>> where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IE: Into + Send + 'static, { - // Transform each item in the TCP stream into a TLS stream - tcp_stream.then(move |result| { - // Clone the TLS acceptor for each connection - // This is necessary because the acceptor is moved into the async block - let tls = tls.clone(); + async_stream::stream! { + // Pin the TCP stream to the stack so that it's available and not moved in the loop + let mut tcp_stream = pin!(tcp_stream); - async move { + while let Some(result) = tcp_stream.next().await { match result { - // If the TCP connection was successfully established Ok(io) => { - // Attempt to perform the TLS handshake - // If successful, return the TLS stream; otherwise, wrap the error - tls.accept(io).await.map_err(Error::from) + // Attempt to perform the TLS handshake on the accepted TCP connection + match tls.accept(io).await { + Ok(tls_stream) => { + // Successful TLS handshake, yield the encrypted stream + yield Ok(tls_stream) + }, + Err(e) => { + // Handle TLS handshake errors + // Convert the rustls error to a TransportError for consistent error handling + let transport_error = >::into(e); + match handle_accept_error(transport_error) { + ControlFlow::Continue(()) => { + // Non-fatal error, skip this connection and continue to the next + continue; + }, + ControlFlow::Break(e) => { + // Fatal error, yield the error and potentially end the stream + yield Err(e) + } + } + } + } + } + Err(e) => { + // Handle TCP connection errors + // These errors are from the underlying TCP stream and are already `TransportError`s + match handle_accept_error(e.into()) { + ControlFlow::Continue(()) => { + // Non-fatal error, skip this connection and continue to the next + continue; + }, + ControlFlow::Break(e) => { + // Fatal error, yield the error and potentially end the stream + yield Err(e) + } + } } - // If there was an error establishing the TCP connection, propagate it - Err(e) => Err(e), } } - }) + } } /// Load the public certificate from a file. @@ -115,12 +176,20 @@ mod tests { use rustls::{ClientConfig, ServerConfig}; use std::net::SocketAddr; use std::sync::Arc; - use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::TlsAcceptor; use tokio_stream::wrappers::TcpListenerStream; use tracing::{debug, error, info, warn}; + fn init_crypto_provider() { + match rustls::crypto::aws_lc_rs::default_provider().install_default() { + Ok(_) => debug!("Default crypto provider installed successfully"), + Err(_) => { + // Crypto provider already installed + } + } + } + // Helper function to create a TLS acceptor for testing async fn create_test_tls_acceptor() -> io::Result { debug!("Creating test TLS acceptor"); @@ -140,6 +209,7 @@ mod tests { #[tokio::test] async fn test_tls_incoming_success() -> Result<(), Box> { + init_crypto_provider(); let _guard = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .try_init(); @@ -159,7 +229,7 @@ mod tests { // Spawn the server task let server_task = tokio::spawn(async move { debug!("Server task started"); - let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let mut tls_stream = Box::pin(serve_tls_incoming(tcp_incoming, tls_acceptor)); let result = tls_stream.next().await; debug!("Server received connection: {:?}", result.is_some()); result @@ -197,6 +267,7 @@ mod tests { #[tokio::test] async fn test_tls_incoming_invalid_cert() -> Result<(), Box> { + init_crypto_provider(); let _guard = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .try_init(); @@ -217,17 +288,21 @@ mod tests { .with_no_client_auth() .with_single_cert(invalid_cert, key); - match config_result { - Ok(_) => warn!("ServerConfig creation unexpectedly succeeded with invalid cert"), - Err(e) => info!("ServerConfig creation failed as expected: {}", e), - } + assert!( + config_result.is_err(), + "ServerConfig creation should fail with invalid cert" + ); + info!( + "ServerConfig creation failed as expected: {}", + config_result.unwrap_err() + ); - // Use a valid certificate for the server to allow the test to continue + // Now test with a valid certificate let valid_certs = load_certs("examples/sample.pem")?; let valid_key = load_private_key("examples/sample.rsa")?; let config = ServerConfig::builder() .with_no_client_auth() - .with_single_cert(valid_certs, valid_key) + .with_single_cert(valid_certs.clone(), valid_key) .expect("ServerConfig creation should succeed with valid cert"); let tls_acceptor = TlsAcceptor::from(Arc::new(config)); @@ -235,13 +310,10 @@ mod tests { // Use serve_tcp_incoming to handle TCP connections let tcp_incoming = serve_tcp_incoming(incoming); - // Spawn the server task + // Spawn the server task with a timeout let server_task = tokio::spawn(async move { - debug!("Server task started"); - let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); - let result = tls_stream.next().await; - debug!("Server received connection: {:?}", result.is_some()); - result + let mut tls_stream = Box::pin(serve_tls_incoming(tcp_incoming, tls_acceptor)); + tokio::time::timeout(std::time::Duration::from_millis(10), tls_stream.next()).await }); // Connect to the server with a TLS client that doesn't trust the server's certificate @@ -257,27 +329,40 @@ mod tests { // This connection should fail due to certificate verification let client_result = connector.connect(domain, tcp_stream).await; - match &client_result { - Ok(_) => warn!("Client connection succeeded unexpectedly"), - Err(e) => info!("Client connection failed as expected: {}", e), - } - assert!(client_result.is_err()); + assert!( + client_result.is_err(), + "Client connection should fail due to untrusted certificate" + ); + info!( + "Client connection failed as expected: {}", + client_result.unwrap_err() + ); - // The server should not encounter an error, but the connection should not be established - let server_result = server_task - .await? - .ok_or("Server task completed without result")?; - match &server_result { - Ok(_) => warn!("Server accepted connection unexpectedly"), - Err(e) => info!("Server did not establish connection as expected: {}", e), + // Wait for the server task to complete or timeout + let server_result = server_task.await?; + + match server_result { + Ok(Some(Ok(_))) => { + warn!("Server accepted connection unexpectedly"); + panic!("Server should not establish connection"); + } + Ok(Some(Err(e))) => { + info!("Server did not establish connection as expected: {}", e); + } + Ok(None) => { + info!("Server timed out waiting for connection, as expected"); + } + Err(e) => { + info!("Server task timed out: {}", e); + } } - assert!(server_result.is_err()); Ok(()) } #[tokio::test] async fn test_tls_incoming_client_hello_timeout() -> Result<(), Box> { + init_crypto_provider(); let _guard = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .try_init(); @@ -297,7 +382,7 @@ mod tests { // Spawn the server task let server_task = tokio::spawn(async move { debug!("Server task started"); - let mut tls_stream = Box::pin(tls_incoming(tcp_incoming, tls_acceptor)); + let mut tls_stream = Box::pin(serve_tls_incoming(tcp_incoming, tls_acceptor)); let result = tokio::time::timeout(std::time::Duration::from_millis(10), tls_stream.next()).await; debug!("Server task completed with result: {:?}", result.is_err()); @@ -319,64 +404,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_load_certs() -> io::Result<()> { - let _guard = tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .try_init(); - - info!("Starting test_load_certs"); - let certs = load_certs("examples/sample.pem")?; - debug!("Loaded {} certificates", certs.len()); - assert!(!certs.is_empty(), "Certificate file should not be empty"); - Ok(()) - } - - #[tokio::test] - async fn test_load_private_key() -> io::Result<()> { - let _guard = tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .try_init(); - - info!("Starting test_load_private_key"); - let key = load_private_key("examples/sample.rsa")?; - debug!("Loaded private key, length: {}", key.secret_der().len()); - assert!( - !key.secret_der().is_empty(), - "Private key should not be empty" - ); - Ok(()) - } - - // Simulating the tls_incoming function for testing purposes - // Replace this with your actual implementation - fn tls_incoming( - incoming: impl Stream> + Send + 'static, - tls_acceptor: TlsAcceptor, - ) -> impl Stream, Error>> + Send + 'static - where - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, - { - Box::pin(incoming.then(move |result| { - let tls_acceptor = tls_acceptor.clone(); - async move { - match result { - Ok(io) => { - debug!("Accepting TLS connection"); - let accept_result = tls_acceptor.accept(io).await.map_err(Error::from); - match &accept_result { - Ok(_) => debug!("TLS connection accepted successfully"), - Err(e) => warn!("Failed to accept TLS connection: {}", e), - } - accept_result - } - Err(e) => { - warn!("Error in incoming connection: {}", e); - Err(e) - } - } - } - })) - } }