Skip to content

Commit

Permalink
feat: higher level server and tests for http
Browse files Browse the repository at this point in the history
  • Loading branch information
0xAlcibiades committed Sep 10, 2024
1 parent 42beb88 commit 82e39be
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 9 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ hyper = "1.4.1"
hyper-util = { version = "0.1.8", features = ["server", "tokio", "server-auto", "server-graceful", "service"] }
rustls = "0.23.13"
rustls-pemfile = "2.1.3"
tokio = { version = "1.40.0", features = ["net"] }
tokio = { version = "1.40.0", features = ["net", "macros"] }
tokio-rustls = "0.26.0"
tower = { version = "0.5.1", features = ["util"] }
tracing = "0.1.40"
Expand All @@ -30,5 +30,6 @@ async-stream = "0.3.5"
futures = "0.3.30"

[dev-dependencies]
tokio = { version = "1.0", features = ["rt", "net", "test-util", "macros"] }
tokio = { version = "1.0", features = ["rt", "net", "test-util"] }
tokio-test = "0.4.4"
hyper = {version = "1.4.1", features = ["client"] }
259 changes: 252 additions & 7 deletions src/http.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
use tokio_stream::StreamExt as _;
use tracing::{debug, trace};
use hyper_util::{
rt::{TokioExecutor, TokioIo, TokioTimer},
server::conn::auto::{Builder as HttpConnectionBuilder, HttpServerConnExec},
service::TowerToHyperService,
};
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
use hyper::{body::Incoming, service::Service as HyperService};
use pin_project::pin_project;
use std::future::pending;
use std::{
convert::Infallible,
fmt,
future::{self, poll_fn, Future},
marker::PhantomData,
net::SocketAddr,
pin::{pin, Pin},
sync::Arc,
task::{ready, Context, Poll},
time::Duration,
};
use futures::Sink;
use http_body::Body;
use hyper::body::Incoming;
use hyper::service::Service;
use hyper_util::server::conn::auto::{Builder, HttpServerConnExec};
use std::future::pending;
use std::pin::pin;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use tokio::time::sleep;
use tracing::{debug, trace};
use tokio_stream::Stream;
use tokio_stream::wrappers::TcpListenerStream;
use crate::fuse::Fuse;

/// Sleeps for a specified duration or waits indefinitely.
///
Expand Down Expand Up @@ -47,7 +70,7 @@ async fn sleep_or_pending(wait_for: Option<Duration>) {
pub(crate) async fn serve_http_connection<B, IO, S, E>(
hyper_io: IO,
hyper_service: S,
builder: Builder<E>,
builder: HttpConnectionBuilder<E>,
mut watcher: Option<tokio::sync::watch::Receiver<()>>,
max_connection_age: Option<Duration>,
) where
Expand Down Expand Up @@ -106,3 +129,225 @@ pub(crate) async fn serve_http_connection<B, IO, S, E>(
trace!("HTTP connection closed");
});
}

pub(crate) async fn serve_http_with_shutdown<S, I, F, IO, IE, ResBody>(
service: S,
incoming: I,
signal: Option<F>,
) -> Result<(), super::Error>
where
F: Future<Output = ()>,
I: Stream<Item = Result<IO, IE>> + Send + 'static,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::Error> + Send + 'static,
S: Service<Request<Incoming>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
ResBody: Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<crate::Error> + Send + Sync,
{
let incoming = crate::tcp::serve_tcp_incoming(
incoming
);

let server = {
let mut builder = HttpConnectionBuilder::new(TokioExecutor::new());
builder
};

let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
let signal_tx = Arc::new(signal_tx);

let graceful = signal.is_some();
let mut sig = pin!(Fuse { inner: signal });
let mut incoming = pin!(incoming);

loop {
tokio::select! {
_ = &mut sig => {
trace!("signal received, shutting down");
break;
},
io = incoming.next() => {
let io = match io {
Some(Ok(io)) => io,
Some(Err(e)) => {
trace!("error accepting connection: {:#}", e);
continue;
},
None => {
break
},
};

trace!("connection accepted");

let hyper_io = TokioIo::new(io);
let hyper_svc = service.clone();

serve_http_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), None).await;
}
}
}

if graceful {
let _ = signal_tx.send(());
drop(signal_rx);
trace!(
"waiting for {} connections to close",
signal_tx.receiver_count()
);

// Wait for all connections to close
signal_tx.closed().await;
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::{TcpStream, TcpListener};
use tokio::sync::oneshot;
use bytes::Bytes;
use http_body_util::{BodyExt, Empty, Full};
use hyper::{body::Incoming, Request, Response, StatusCode};
use tokio_stream::wrappers::TcpListenerStream;
use tower::ServiceExt;

// Echo service
async fn echo(req: Request<Incoming>) -> Result<Response<Full<Bytes>>, hyper::Error> {
match (req.method(), req.uri().path()) {
(&hyper::Method::GET, "/") => Ok(Response::new(Full::new(Bytes::from("Hello, World!")))),
(&hyper::Method::POST, "/echo") => {
let body = req.collect().await?.to_bytes();
Ok(Response::new(Full::new(body)))
}
_ => {
let mut res = Response::new(Full::new(Bytes::from("Not Found")));
*res.status_mut() = StatusCode::NOT_FOUND;
Ok(res)
}
}
}

async fn setup_test_server(addr: SocketAddr) -> (TcpListenerStream, SocketAddr) {
let listener = TcpListener::bind(addr).await.unwrap();
let server_addr = listener.local_addr().unwrap();
let incoming = TcpListenerStream::new(listener);
(incoming, server_addr)
}

async fn send_request(addr: SocketAddr, req: Request<Empty<Bytes>>) -> hyper::Result<Response<Incoming>> {
let stream = TcpStream::connect(addr).await.unwrap();
let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::spawn(async move {
if let Err(err) = conn.await {
eprintln!("Connection failed: {:?}", err);
}
});

sender.send_request(req).await
}

#[tokio::test]
async fn test_serve_http_with_shutdown_basic() {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let (incoming, server_addr) = setup_test_server(addr).await;

let (shutdown_tx, shutdown_rx) = oneshot::channel();

let tower_service_fn = tower::service_fn(echo);
let hyper_service = TowerToHyperService::new(tower_service_fn);

let server = tokio::spawn(serve_http_with_shutdown(
hyper_service,
incoming,
Some(async {
shutdown_rx.await.ok();
}),
));

// Test GET request
let req = Request::builder()
.uri("/")
.body(Empty::<Bytes>::new())
.unwrap();
let res = send_request(server_addr, req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body = res.collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"Hello, World!");

// Test POST request
let req = Request::builder()
.method(hyper::Method::POST)
.uri("/echo")
.body(Empty::<Bytes>::new())
.unwrap();
let res = send_request(server_addr, req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);

// Test 404 response
let req = Request::builder()
.uri("/not_found")
.body(Empty::<Bytes>::new())
.unwrap();
let res = send_request(server_addr, req).await.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);

// Shutdown the server
shutdown_tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), server).await
.expect("Server didn't shut down within the timeout period")
.unwrap()
.unwrap();
}

#[tokio::test]
async fn test_serve_http_with_concurrent_requests() {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let (incoming, server_addr) = setup_test_server(addr).await;

let (shutdown_tx, shutdown_rx) = oneshot::channel();

let tower_service_fn = tower::service_fn(echo);
let hyper_service = TowerToHyperService::new(tower_service_fn);

let server = tokio::spawn(serve_http_with_shutdown(
hyper_service,
incoming,
Some(async {
shutdown_rx.await.ok();
}),
));

let mut handles = vec![];
for _ in 0..10 {
let handle = tokio::spawn(async move {
let req = Request::builder()
.uri("/")
.body(Empty::<Bytes>::new())
.unwrap();
let res = send_request(server_addr, req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
});
handles.push(handle);
}

for handle in handles {
handle.await.unwrap();
}

// Shutdown the server
shutdown_tx.send(()).unwrap();
tokio::time::timeout(Duration::from_secs(5), server).await
.expect("Server didn't shut down within the timeout period")
.unwrap()
.unwrap();
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ mod tcp;
mod tls;

pub(crate) type Error = Box<dyn std::error::Error + Send + Sync>;


0 comments on commit 82e39be

Please sign in to comment.