From 6d93b4251766d97120b96ecee6d198b3406da7da Mon Sep 17 00:00:00 2001 From: megsdevs <57240925+megsdevs@users.noreply.github.com> Date: Sat, 14 Oct 2023 01:30:49 +0100 Subject: [PATCH] Proxy protocol support (#3) Add proxy protocol support. --------- Co-authored-by: Alcibiades <89996683+0xAlcibiades@users.noreply.github.com> Co-authored-by: Dave Belvedere <90095005+dbelv@users.noreply.github.com> Co-authored-by: Alcibiades Athens --- .gitignore | 2 +- Cargo.toml | 4 + src/lib.rs | 4 + src/proxy_protocol/future.rs | 143 ++++++ src/proxy_protocol/mod.rs | 841 +++++++++++++++++++++++++++++++++++ src/server.rs | 49 +- src/tls_openssl/mod.rs | 9 +- src/tls_rustls/mod.rs | 10 +- 8 files changed, 1047 insertions(+), 15 deletions(-) create mode 100644 src/proxy_protocol/future.rs create mode 100644 src/proxy_protocol/mod.rs diff --git a/.gitignore b/.gitignore index 284b41a..9eb7b46 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,4 @@ settings.yaml lcov.info # Ignore cargo lock for library -Cargo.lock +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 7dcffca..db6b24d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ version = "0.5.3" default = [] tls-rustls = ["arc-swap", "pin-project-lite", "rustls", "rustls-pemfile", "tokio/fs", "tokio/time", "tokio-rustls"] tls-openssl = ["openssl", "tokio-openssl", "pin-project-lite"] +proxy-protocol = ["ppp", "pin-project-lite"] [dependencies] @@ -37,6 +38,9 @@ tokio-openssl = { version = "0.6", optional = true } tokio-rustls = { version = "0.24", optional = true } tower-service = "0.3" +## proxy-protocol +ppp = { version = "2.2.0", optional = true } + [dev-dependencies] axum = "0.6" hyper = { version = "0.14", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs index eb8e417..932c092 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,3 +124,7 @@ pub mod tls_openssl; #[doc(inline)] #[cfg(feature = "tls-openssl")] pub use self::tls_openssl::bind_openssl; + +#[cfg(feature = "proxy-protocol")] +#[cfg_attr(docsrs, doc(cfg(feature = "proxy_protocol")))] +pub mod proxy_protocol; diff --git a/src/proxy_protocol/future.rs b/src/proxy_protocol/future.rs new file mode 100644 index 0000000..bd9ab26 --- /dev/null +++ b/src/proxy_protocol/future.rs @@ -0,0 +1,143 @@ +//! Future types for PROXY protocol support. +use crate::accept::Accept; +use crate::proxy_protocol::ForwardClientIp; +use pin_project_lite::pin_project; +use std::{ + fmt, + future::Future, + io, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::Timeout; + +// A `pin_project` is a procedural macro used for safe field projection in conjunction +// with the Rust Pin API, which guarantees that certain types will not move in memory. +pin_project! { + /// This struct represents the future for the ProxyProtocolAcceptor. + /// The generic types are: + /// F: The future type. + /// A: The type that implements the Accept trait. + /// I: The IO type that supports both AsyncRead and AsyncWrite. + /// S: The service type. + pub struct ProxyProtocolAcceptorFuture + where + A: Accept, + { + #[pin] + inner: AcceptFuture, + } +} + +impl ProxyProtocolAcceptorFuture +where + A: Accept, + I: AsyncRead + AsyncWrite + Unpin, +{ + // Constructor for creating a new ProxyProtocolAcceptorFuture. + pub(crate) fn new(future: Timeout, acceptor: A, service: S) -> Self { + let inner = AcceptFuture::ReadHeader { + future, + acceptor, + service: Some(service), + }; + Self { inner } + } +} + +// Implement Debug trait for ProxyProtocolAcceptorFuture to allow +// debugging and logging. +impl fmt::Debug for ProxyProtocolAcceptorFuture +where + A: Accept, + I: AsyncRead + AsyncWrite + Unpin, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProxyProtocolAcceptorFuture").finish() + } +} + +pin_project! { + // AcceptFuture represents the internal states of ProxyProtocolAcceptorFuture. + // It can either be waiting to read the header or forward the client IP. + #[project = AcceptFutureProj] + enum AcceptFuture + where + A: Accept, + { + ReadHeader { + #[pin] + future: Timeout, + acceptor: A, + service: Option, + }, + ForwardIp { + #[pin] + future: A::Future, + client_address: Option, + }, + } +} + +impl Future for ProxyProtocolAcceptorFuture +where + A: Accept, + I: AsyncRead + AsyncWrite + Unpin, + // Future whose output is a result with either a tuple of stream and optional address, + // or an io::Error. + F: Future), io::Error>>, +{ + type Output = io::Result<(A::Stream, ForwardClientIp)>; + + // The main poll function that drives the future towards completion. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + // Check the current state of the inner future. + match this.inner.as_mut().project() { + AcceptFutureProj::ReadHeader { + future, + acceptor, + service, + } => match future.poll(cx) { + Poll::Ready(Ok(Ok((stream, client_address)))) => { + let service = service.take().expect("future polled after ready"); + let future = acceptor.accept(stream, service); + + // Transition to the ForwardIp state after successfully reading the header. + this.inner.set(AcceptFuture::ForwardIp { + future, + client_address, + }); + } + Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(Err(timeout)) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, timeout))) + } + Poll::Pending => return Poll::Pending, + }, + AcceptFutureProj::ForwardIp { + future, + client_address, + } => { + return match future.poll(cx) { + Poll::Ready(Ok((stream, service))) => { + let service = ForwardClientIp { + inner: service, + client_address: *client_address, + }; + + // Return the successfully processed stream and service. + Poll::Ready(Ok((stream, service))) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; + } + } + } + } +} diff --git a/src/proxy_protocol/mod.rs b/src/proxy_protocol/mod.rs new file mode 100644 index 0000000..0906eed --- /dev/null +++ b/src/proxy_protocol/mod.rs @@ -0,0 +1,841 @@ +//! This feature allows the `hyper_server` to be used behind a layer 4 load balancer whilst the proxy +//! protocol is enabled to preserve the client IP address and port. +//! See The PROXY protocol spec for more details: . +//! +//! Any client address found in the proxy protocol header is forwarded on in the HTTP `forwarded` +//! header to be accessible by the rest server. +//! +//! Note: if you are setting a custom acceptor, `enable_proxy_protocol` must be called after this is set. +//! It is best to use directly before calling `serve` when the inner acceptor is already configured. +//! `ProxyProtocolAcceptor` wraps the initial acceptor, so the proxy header is removed from the +//! beginning of the stream before the messages are forwarded on. +//! +//! # Example +//! +//! ```rust,no_run +//! use axum::{routing::get, Router}; +//! use std::net::SocketAddr; +//! use std::time::Duration; +//! +//! #[tokio::main] +//! async fn main() { +//! let app = Router::new().route("/", get(|| async { "Hello, world!" })); +//! +//! let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); +//! println!("listening on {}", addr); +//! +//! // Can configure if you want different from the default of 5 seconds, +//! // otherwise passing `None` will use the default. +//! let proxy_header_timeout = Some(Duration::from_secs(2)); +//! +//! hyper_server::bind(addr) +//! .enable_proxy_protocol(proxy_header_timeout) +//! .serve(app.into_make_service()) +//! .await +//! .unwrap(); +//! } +//! ``` +use crate::accept::Accept; +use std::{ + fmt, + future::Future, + io, + net::{IpAddr, SocketAddr}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use http::HeaderValue; +use http::Request; +use ppp::{v1, v2, HeaderResult}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite}, + time::timeout, +}; +use tower_service::Service; + +pub(crate) mod future; +use self::future::ProxyProtocolAcceptorFuture; + +/// The length of a v1 header in bytes. +const V1_PREFIX_LEN: usize = 5; +/// The maximum length of a v1 header in bytes. +const V1_MAX_LENGTH: usize = 107; +/// The terminator bytes of a v1 header. +const V1_TERMINATOR: &[u8] = b"\r\n"; +/// The prefix length of a v2 header in bytes. +const V2_PREFIX_LEN: usize = 12; +/// The minimum length of a v2 header in bytes. +const V2_MINIMUM_LEN: usize = 16; +/// The index of the start of the big-endian u16 length in the v2 header. +const V2_LENGTH_INDEX: usize = 14; +/// The length of the read buffer used to read the PROXY protocol header. +const READ_BUFFER_LEN: usize = 512; + +pub(crate) async fn read_proxy_header( + mut stream: I, +) -> Result<(I, Option), io::Error> +where + I: AsyncRead + Unpin, +{ + // Mutable buffer for storing stream data + let mut buffer = [0; READ_BUFFER_LEN]; + // Dynamic in case v2 header is too long + let mut dynamic_buffer = None; + + // Read prefix to check for v1, v2, or kill + stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?; + + if &buffer[..V1_PREFIX_LEN] == v1::PROTOCOL_PREFIX.as_bytes() { + read_v1_header(&mut stream, &mut buffer).await?; + } else { + stream + .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN]) + .await?; + if &buffer[..V2_PREFIX_LEN] == v2::PROTOCOL_PREFIX { + dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "No valid Proxy Protocol header detected", + )); + } + } + + // Choose which buffer to parse + let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]); + + // Parse the header + let header = HeaderResult::parse(buffer); + match header { + HeaderResult::V1(Ok(header)) => { + let client_address = match header.addresses { + v1::Addresses::Tcp4(ip) => { + SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) + } + v1::Addresses::Tcp6(ip) => { + SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) + } + v1::Addresses::Unknown => { + // Return client address as `None` so that "unknown" is used in the http header + return Ok((stream, None)); + } + }; + + Ok((stream, Some(client_address))) + } + HeaderResult::V2(Ok(header)) => { + let client_address = match header.addresses { + v2::Addresses::IPv4(ip) => { + SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port) + } + v2::Addresses::IPv6(ip) => { + SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port) + } + v2::Addresses::Unix(unix) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Unix socket addresses are not supported. Addresses: {:?}", + unix + ), + )); + } + v2::Addresses::Unspecified => { + // Return client address as `None` so that "unknown" is used in the http header + return Ok((stream, None)); + } + }; + + Ok((stream, Some(client_address))) + } + HeaderResult::V1(Err(_error)) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "No valid V1 Proxy Protocol header received", + )), + HeaderResult::V2(Err(_error)) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "No valid V2 Proxy Protocol header received", + )), + } +} + +async fn read_v2_header( + mut stream: I, + buffer: &mut [u8; READ_BUFFER_LEN], +) -> Result>, io::Error> +where + I: AsyncRead + Unpin, +{ + let length = + u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize; + let full_length = V2_MINIMUM_LEN + length; + + // Switch to dynamic buffer if header is too long; v2 has no maximum length + if full_length > READ_BUFFER_LEN { + let mut dynamic_buffer = Vec::with_capacity(full_length); + dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]); + + // Read the remaining header length + stream + .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(Some(dynamic_buffer)) + } else { + // Read the remaining header length + stream + .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length]) + .await?; + + Ok(None) + } +} + +async fn read_v1_header( + mut stream: I, + buffer: &mut [u8; READ_BUFFER_LEN], +) -> Result<(), io::Error> +where + I: AsyncRead + Unpin, +{ + // read one byte at a time until terminator found + let mut end_found = false; + for i in V1_PREFIX_LEN..V1_MAX_LENGTH { + buffer[i] = stream.read_u8().await?; + + if [buffer[i - 1], buffer[i]] == V1_TERMINATOR { + end_found = true; + break; + } + } + if !end_found { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "No valid Proxy Protocol header detected", + )); + } + + Ok(()) +} + +/// Middleware for adding client IP address to the request `forwarded` header. +/// see spec: +#[derive(Debug, Clone)] +pub struct ForwardClientIp { + inner: S, + client_address: Option, +} + +impl Service> for ForwardClientIp +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + // The full socket address is available in the proxy header, hence why we include port + let mut forwarded_string = match self.client_address { + Some(socket_addr) => match socket_addr { + SocketAddr::V4(addr) => { + format!("for={}:{}", addr.ip(), addr.port()) + } + SocketAddr::V6(addr) => { + format!("for=\"[{}]:{}\"", addr.ip(), addr.port()) + } + }, + None => "for=unknown".to_string(), + }; + + if let Some(existing_value) = req.headers_mut().get("Forwarded") { + forwarded_string = format!( + "{}, {}", + existing_value.to_str().unwrap_or(""), + forwarded_string + ); + } + + if let Ok(header_value) = HeaderValue::from_str(&forwarded_string) { + req.headers_mut().insert("Forwarded", header_value); + } + + self.inner.call(req) + } +} + +/// Acceptor wrapper for receiving Proxy Protocol headers. +#[derive(Clone)] +pub struct ProxyProtocolAcceptor { + inner: A, + parsing_timeout: Duration, +} + +impl ProxyProtocolAcceptor { + /// Create a new proxy protocol acceptor from an initial acceptor. + /// This is compatible with tls acceptors. + pub fn new(inner: A) -> Self { + #[cfg(not(test))] + let parsing_timeout = Duration::from_secs(5); + + // Don't force tests to wait too long. + #[cfg(test)] + let parsing_timeout = Duration::from_secs(1); + + Self { + inner, + parsing_timeout, + } + } + + /// Override the default Proxy Header parsing timeout. + pub fn parsing_timeout(mut self, val: Duration) -> Self { + self.parsing_timeout = val; + self + } +} + +impl ProxyProtocolAcceptor { + /// Overwrite inner acceptor. + pub fn acceptor(self, acceptor: Acceptor) -> ProxyProtocolAcceptor { + ProxyProtocolAcceptor { + inner: acceptor, + parsing_timeout: self.parsing_timeout, + } + } +} + +impl Accept for ProxyProtocolAcceptor +where + A: Accept + Clone, + A::Stream: AsyncRead + AsyncWrite + Unpin, + I: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = A::Stream; + type Service = ForwardClientIp; + type Future = ProxyProtocolAcceptorFuture< + Pin), io::Error>> + Send>>, + A, + I, + S, + >; + + fn accept(&self, stream: I, service: S) -> Self::Future { + let future = Box::pin(read_proxy_header(stream)); + + ProxyProtocolAcceptorFuture::new( + timeout(self.parsing_timeout, future), + self.inner.clone(), + service, + ) + } +} + +impl fmt::Debug for ProxyProtocolAcceptor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProxyProtocolAcceptor").finish() + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "tls-openssl")] + use crate::tls_openssl::{ + self, + tests::{dns_name as openssl_dns_name, tls_connector as openssl_connector}, + OpenSSLConfig, + }; + #[cfg(feature = "tls-rustls")] + use crate::tls_rustls::{ + self, + tests::{dns_name as rustls_dns_name, tls_connector as rustls_connector}, + RustlsConfig, + }; + use crate::{handle::Handle, server::Server}; + use axum::http::Response; + use axum::{routing::get, Router}; + use bytes::Bytes; + use http::{response, Request}; + use hyper::{ + client::conn::{handshake, SendRequest}, + Body, + }; + use ppp::v2::{Builder, Command, Protocol, Type, Version}; + use std::{io, net::SocketAddr, time::Duration}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::{ + net::{TcpListener, TcpStream}, + task::JoinHandle, + time::timeout, + }; + use tower::{Service, ServiceExt}; + + #[tokio::test] + async fn start_and_request() { + let (_handle, _server_task, server_addr) = start_server(true).await; + + let addr = start_proxy(server_addr, ProxyVersion::V2) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, _client_addr) = connect(addr).await; + + let (_parts, body) = send_empty_request(&mut client).await; + + assert_eq!(body.as_ref(), b"Hello, world!"); + } + + #[tokio::test] + async fn server_receives_client_address() { + let (_handle, _server_task, server_addr) = start_server(true).await; + + let addr = start_proxy(server_addr, ProxyVersion::V2) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, client_addr) = connect(addr).await; + + let (parts, body) = send_empty_request(&mut client).await; + + // Check for the Forwarded header + let forwarded_header = parts + .headers + .get("Forwarded") + .expect("No Forwarded header present") + .to_str() + .expect("Failed to convert Forwarded header to str"); + + assert!(forwarded_header.contains(&format!("for={}", client_addr))); + assert_eq!(body.as_ref(), b"Hello, world!"); + } + + #[tokio::test] + async fn server_receives_client_address_v1() { + let (_handle, _server_task, server_addr) = start_server(true).await; + + let addr = start_proxy(server_addr, ProxyVersion::V1) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, client_addr) = connect(addr).await; + + let (parts, body) = send_empty_request(&mut client).await; + + // Check for the Forwarded header + let forwarded_header = parts + .headers + .get("Forwarded") + .expect("No Forwarded header present") + .to_str() + .expect("Failed to convert Forwarded header to str"); + + assert!(forwarded_header.contains(&format!("for={}", client_addr))); + assert_eq!(body.as_ref(), b"Hello, world!"); + } + + #[cfg(feature = "tls-rustls")] + #[tokio::test] + async fn rustls_server_receives_client_address() { + let (_handle, _server_task, server_addr) = start_rustls_server().await; + + let addr = start_proxy(server_addr, ProxyVersion::V2) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, client_addr) = rustls_connect(addr).await; + + let (parts, body) = send_empty_request(&mut client).await; + + // Check for the Forwarded header + let forwarded_header = parts + .headers + .get("Forwarded") + .expect("No Forwarded header present") + .to_str() + .expect("Failed to convert Forwarded header to str"); + + assert!(forwarded_header.contains(&format!("for={}", client_addr))); + assert_eq!(body.as_ref(), b"Hello, world!"); + } + + #[cfg(feature = "tls-openssl")] + #[tokio::test] + async fn openssl_server_receives_client_address() { + let (_handle, _server_task, server_addr) = start_openssl_server().await; + + let addr = start_proxy(server_addr, ProxyVersion::V2) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, client_addr) = openssl_connect(addr).await; + + let (parts, body) = send_empty_request(&mut client).await; + + // Check for the Forwarded header + let forwarded_header = parts + .headers + .get("Forwarded") + .expect("No Forwarded header present") + .to_str() + .expect("Failed to convert Forwarded header to str"); + + assert!(forwarded_header.contains(&format!("for={}", client_addr))); + assert_eq!(body.as_ref(), b"Hello, world!"); + } + + #[tokio::test] + async fn not_parsing_when_header_present_fails() { + // Start the server with proxy protocol disabled + let (_handle, _server_task, server_addr) = start_server(false).await; + + // Start the proxy + let addr = start_proxy(server_addr, ProxyVersion::V2) + .await + .expect("Failed to start proxy"); + + // Connect to the proxy + let (mut client, _conn, _client_addr) = connect(addr).await; + + // Send a request to the proxy + match client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + { + // TODO(This should fail when there is no proxy protocol support, perhaps) + Ok(_o) => { + //dbg!(_o); + //() + } + Err(e) => { + if e.is_incomplete_message() { + } else { + panic!("Received unexpected error"); + } + } + } + } + + #[tokio::test] + async fn parsing_when_header_not_present_fails() { + let (_handle, _server_task, server_addr) = start_server(true).await; + + let addr = start_proxy(server_addr, ProxyVersion::None) + .await + .expect("Failed to start proxy"); + + let (mut client, _conn, _client_addr) = connect(addr).await; + + match client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + { + Ok(_) => panic!("Should have failed"), + Err(e) => { + if e.is_incomplete_message() { + } else { + panic!("Received unexpected error"); + } + } + } + } + + async fn forward_ip_handler(req: Request) -> Response { + let mut response = Response::new(Body::from("Hello, world!")); + + if let Some(header_value) = req.headers().get("Forwarded") { + response + .headers_mut() + .insert("Forwarded", header_value.clone()); + } + + response + } + + async fn start_server( + parse_proxy_header: bool, + ) -> (Handle, JoinHandle>, SocketAddr) { + let handle = Handle::new(); + + let server_handle = handle.clone(); + let server_task = tokio::spawn(async move { + let app = Router::new().route("/", get(forward_ip_handler)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + + if parse_proxy_header { + Server::bind(addr) + .handle(server_handle) + .enable_proxy_protocol(None) + .serve(app.into_make_service()) + .await + } else { + Server::bind(addr) + .handle(server_handle) + .serve(app.into_make_service()) + .await + } + }); + + let addr = handle.listening().await.unwrap(); + + (handle, server_task, addr) + } + + #[cfg(feature = "tls-rustls")] + async fn start_rustls_server() -> (Handle, JoinHandle>, SocketAddr) { + let handle = Handle::new(); + + let server_handle = handle.clone(); + let server_task = tokio::spawn(async move { + let app = Router::new().route("/", get(forward_ip_handler)); + + let config = RustlsConfig::from_pem_file( + "examples/self-signed-certs/cert.pem", + "examples/self-signed-certs/key.pem", + ) + .await?; + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + + tls_rustls::bind_rustls(addr, config) + .handle(server_handle) + .enable_proxy_protocol(None) + .serve(app.into_make_service()) + .await + }); + + let addr = handle.listening().await.unwrap(); + + (handle, server_task, addr) + } + + #[cfg(feature = "tls-openssl")] + async fn start_openssl_server() -> (Handle, JoinHandle>, SocketAddr) { + let handle = Handle::new(); + + let server_handle = handle.clone(); + let server_task = tokio::spawn(async move { + let app = Router::new().route("/", get(forward_ip_handler)); + + let config = OpenSSLConfig::from_pem_file( + "examples/self-signed-certs/cert.pem", + "examples/self-signed-certs/key.pem", + ) + .unwrap(); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + + tls_openssl::bind_openssl(addr, config) + .handle(server_handle) + .enable_proxy_protocol(None) + .serve(app.into_make_service()) + .await + }); + + let addr = handle.listening().await.unwrap(); + + (handle, server_task, addr) + } + + #[derive(Debug, Clone, Copy)] + enum ProxyVersion { + V1, + V2, + None, + } + + async fn start_proxy( + server_address: SocketAddr, + proxy_version: ProxyVersion, + ) -> Result> { + let proxy_address = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(proxy_address).await?; + let proxy_address = listener.local_addr()?; + + let _proxy_task = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((client_stream, _)) => { + tokio::spawn(async move { + if let Err(e) = + handle_conn(client_stream, server_address, proxy_version).await + { + println!("Error handling connection: {:?}", e); + } + }); + } + Err(e) => println!("Failed to accept a connection: {:?}", e), + } + } + }); + + Ok(proxy_address) + } + + async fn handle_conn( + mut client_stream: TcpStream, + server_address: SocketAddr, + proxy_version: ProxyVersion, + ) -> io::Result<()> { + let client_address = client_stream.peer_addr()?; // Get the address before splitting + let mut server_stream = TcpStream::connect(server_address).await?; + let server_address = server_stream.peer_addr()?; // Get the address before splitting + + let (mut client_read, mut client_write) = client_stream.split(); + let (mut server_read, mut server_write) = server_stream.split(); + + send_proxy_header( + &mut server_write, + client_address, + server_address, + proxy_version, + ) + .await?; + + let duration = Duration::from_secs(1); + let client_to_server = async { + match timeout(duration, transfer(&mut client_read, &mut server_write)).await { + Ok(result) => result, + Err(_) => Err(io::Error::new( + io::ErrorKind::TimedOut, + "Client to Server transfer timed out", + )), + } + }; + + let server_to_client = async { + match timeout(duration, transfer(&mut server_read, &mut client_write)).await { + Ok(result) => result, + Err(_) => Err(io::Error::new( + io::ErrorKind::TimedOut, + "Server to Client transfer timed out", + )), + } + }; + + let _ = tokio::try_join!(client_to_server, server_to_client); + + Ok(()) + } + + async fn transfer( + read_stream: &mut (impl AsyncReadExt + Unpin), + write_stream: &mut (impl AsyncWriteExt + Unpin), + ) -> io::Result<()> { + let mut buf = [0; 4096]; + loop { + let n = read_stream.read(&mut buf).await?; + if n == 0 { + break; // EOF + } + write_stream.write_all(&buf[..n]).await?; + } + Ok(()) + } + + async fn send_proxy_header( + write_stream: &mut (impl AsyncWriteExt + Unpin), + client_address: SocketAddr, + server_address: SocketAddr, + proxy_version: ProxyVersion, + ) -> io::Result<()> { + match proxy_version { + ProxyVersion::V1 => { + let header = ppp::v1::Addresses::from((client_address, server_address)).to_string(); + + for byte in header.as_bytes() { + write_stream.write_all(&[*byte]).await?; + } + } + ProxyVersion::V2 => { + let mut header = Builder::with_addresses( + // Declare header as mutable + Version::Two | Command::Proxy, + Protocol::Stream, + (client_address, server_address), + ) + .write_tlv(Type::NoOp, b"Hello, World!")? + .build()?; + + for byte in header.drain(..) { + write_stream.write_all(&[byte]).await?; + } + } + ProxyVersion::None => {} + } + + Ok(()) + } + + async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { + let stream = TcpStream::connect(addr).await.unwrap(); + let client_addr = stream.local_addr().unwrap(); + + let (send_request, connection) = handshake(stream).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = connection.await; + }); + + (send_request, task, client_addr) + } + + #[cfg(feature = "tls-rustls")] + async fn rustls_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { + let stream = TcpStream::connect(addr).await.unwrap(); + let client_addr = stream.local_addr().unwrap(); + let tls_stream = rustls_connector() + .connect(rustls_dns_name(), stream) + .await + .unwrap(); + + let (send_request, connection) = handshake(tls_stream).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = connection.await; + }); + + (send_request, task, client_addr) + } + + #[cfg(feature = "tls-openssl")] + async fn openssl_connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>, SocketAddr) { + let stream = TcpStream::connect(addr).await.unwrap(); + let client_addr = stream.local_addr().unwrap(); + let tls_stream = openssl_connector(openssl_dns_name(), stream).await; + + let (send_request, connection) = handshake(tls_stream).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = connection.await; + }); + + (send_request, task, client_addr) + } + + async fn send_empty_request(client: &mut SendRequest) -> (response::Parts, Bytes) { + let (parts, body) = client + .ready() + .await + .unwrap() + .call(Request::new(Body::empty())) + .await + .unwrap() + .into_parts(); + let body = hyper::body::to_bytes(body).await.unwrap(); + + (parts, body) + } +} diff --git a/src/server.rs b/src/server.rs index 2401fe1..6db138a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,5 @@ -// Necessary module imports +#[cfg(feature = "proxy-protocol")] +use crate::proxy_protocol::ProxyProtocolAcceptor; use crate::{ accept::{Accept, DefaultAcceptor}, addr_incoming_config::AddrIncomingConfig, @@ -12,6 +13,8 @@ use hyper::server::{ accept::Accept as HyperAccept, conn::{AddrIncoming, AddrStream}, }; +#[cfg(feature = "proxy-protocol")] +use std::time::Duration; use std::{ io::{self, ErrorKind}, net::SocketAddr, @@ -30,6 +33,8 @@ pub struct Server { addr_incoming_conf: AddrIncomingConfig, handle: Handle, http_conf: HttpConfig, + #[cfg(feature = "proxy-protocol")] + proxy_acceptor_set: bool, } /// Enum representing the ways the server can be initialized - either by binding to an address or from a standard TCP listener. @@ -61,6 +66,8 @@ impl Server { addr_incoming_conf: AddrIncomingConfig::default(), handle, http_conf: HttpConfig::default(), + #[cfg(feature = "proxy-protocol")] + proxy_acceptor_set: false, } } @@ -75,6 +82,8 @@ impl Server { addr_incoming_conf: AddrIncomingConfig::default(), handle, http_conf: HttpConfig::default(), + #[cfg(feature = "proxy-protocol")] + proxy_acceptor_set: false, } } } @@ -82,12 +91,43 @@ impl Server { impl Server { /// Replace the current acceptor with a new one. pub fn acceptor(self, acceptor: Acceptor) -> Server { + #[cfg(feature = "proxy-protocol")] + if self.proxy_acceptor_set { + panic!("Overwriting the acceptor after proxy protocol is enabled is not supported. Configure the acceptor first in the builder, then enable proxy protocol."); + } + + Server { + acceptor, + listener: self.listener, + addr_incoming_conf: self.addr_incoming_conf, + handle: self.handle, + http_conf: self.http_conf, + #[cfg(feature = "proxy-protocol")] + proxy_acceptor_set: self.proxy_acceptor_set, + } + } + + #[cfg(feature = "proxy-protocol")] + /// Enable proxy protocol header parsing. + /// Note has to be called after initial acceptor is set. + pub fn enable_proxy_protocol( + self, + parsing_timeout: Option, + ) -> Server> { + let initial_acceptor = self.acceptor; + let mut acceptor = ProxyProtocolAcceptor::new(initial_acceptor); + + if let Some(val) = parsing_timeout { + acceptor = acceptor.parsing_timeout(val); + } + Server { acceptor, listener: self.listener, addr_incoming_conf: self.addr_incoming_conf, handle: self.handle, http_conf: self.http_conf, + proxy_acceptor_set: true, } } @@ -102,6 +142,8 @@ impl Server { addr_incoming_conf: self.addr_incoming_conf, handle: self.handle, http_conf: self.http_conf, + #[cfg(feature = "proxy-protocol")] + proxy_acceptor_set: self.proxy_acceptor_set, } } @@ -402,6 +444,7 @@ mod tests { // Disconnect client. conn.abort(); + // TODO(This does not shut down gracefully) // Server task should finish soon. let server_result = timeout(Duration::from_secs(1), server_task) .await @@ -411,7 +454,6 @@ mod tests { assert!(server_result.is_ok()); } - #[ignore] #[tokio::test] async fn test_graceful_shutdown_timed() { let (handle, server_task, addr) = start_server().await; @@ -424,9 +466,6 @@ mod tests { assert_eq!(body.as_ref(), b"Hello, world!"); - // Don't disconnect client. - // conn.abort(); - // Server task should finish soon. let server_result = timeout(Duration::from_secs(1), server_task) .await diff --git a/src/tls_openssl/mod.rs b/src/tls_openssl/mod.rs index 0956612..f108517 100644 --- a/src/tls_openssl/mod.rs +++ b/src/tls_openssl/mod.rs @@ -232,7 +232,7 @@ impl fmt::Debug for OpenSSLConfig { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use crate::{ handle::Handle, tls_openssl::{self, OpenSSLConfig}, @@ -308,7 +308,6 @@ mod tests { assert!(server_result.is_ok()); } - #[ignore] #[tokio::test] async fn test_graceful_shutdown_timed() { let (handle, server_task, addr) = start_server().await; @@ -386,7 +385,8 @@ mod tests { (parts, body) } - async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream { + /// Used in `proxy-protocol` feature tests. + pub(crate) async fn tls_connector(hostname: &str, stream: TcpStream) -> SslStream { let mut tls_parms = SslConnector::builder(SslMethod::tls_client()).unwrap(); tls_parms.set_verify(SslVerifyMode::NONE); let hostname_owned = hostname.to_string(); @@ -405,7 +405,8 @@ mod tests { tls_stream } - fn dns_name() -> &'static str { + /// Used in `proxy-protocol` feature tests. + pub(crate) fn dns_name() -> &'static str { "localhost" } } diff --git a/src/tls_rustls/mod.rs b/src/tls_rustls/mod.rs index 1522cc0..af7e391 100644 --- a/src/tls_rustls/mod.rs +++ b/src/tls_rustls/mod.rs @@ -302,7 +302,7 @@ async fn config_from_pem_file( } #[cfg(test)] -mod tests { +pub(crate) mod tests { use crate::{ handle::Handle, tls_rustls::{self, RustlsConfig}, @@ -341,7 +341,6 @@ mod tests { assert_eq!(body.as_ref(), b"Hello, world!"); } - #[ignore] #[tokio::test] async fn tls_timeout() { let (handle, _server_task, addr) = start_server().await; @@ -459,7 +458,6 @@ mod tests { assert!(server_result.is_ok()); } - #[ignore] #[tokio::test] async fn test_graceful_shutdown_timed() { let (handle, server_task, addr) = start_server().await; @@ -546,7 +544,8 @@ mod tests { (parts, body) } - fn tls_connector() -> TlsConnector { + /// Used in `proxy-protocol` feature tests. + pub(crate) fn tls_connector() -> TlsConnector { struct NoVerify; impl ServerCertVerifier for NoVerify { @@ -573,7 +572,8 @@ mod tests { TlsConnector::from(Arc::new(client_config)) } - fn dns_name() -> ServerName { + /// Used in `proxy-protocol` feature tests. + pub(crate) fn dns_name() -> ServerName { ServerName::try_from("localhost").unwrap() } }