Skip to content

Commit

Permalink
feat: well factored http connection handling
Browse files Browse the repository at this point in the history
  • Loading branch information
0xAlcibiades committed Sep 10, 2024
1 parent f5c8c60 commit 2d27fb7
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 27 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ tokio = { version = "1.40.0", features = ["net"] }
tokio-rustls = "0.26.0"
tower = { version = "0.5.1", features = ["util"] }
tracing = "0.1.40"
http-body = "1.0.1"
tokio-stream = "0.1.16"
bytes = "1.7.1"
pin-project = "1.1.5"

[dev-dependencies]
tokio = { version = "1.40.0", features = ["macros"] }
207 changes: 180 additions & 27 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,137 @@
// Crate links:
// - tokio: https://docs.rs/tokio/latest/tokio/
// - rustls: https://docs.rs/rustls/latest/rustls/
// - tokio_rustls:https://docs.rs/tokio-rustls/latest/tokio_rustls/
// - hyper: https://docs.rs/hyper/latest/hyper/
// - tower: https://docs.rs/tower/latest/tower/
//
// We take a `SocketAddr` to bind the server to a specific address.
// We use `TcpListener` to bind the server to the specified address at the TCP layer.
// We use `rustls` to create a new `ServerConfig` instance.
// We use `tokio_rustls` to create a new `TlsAcceptor` instance.
// We use `hyper_util::server::conn::auto` to create a new `Connection` instance.
// Which then passes requests to a `hyper::service::Service` instance.
// Which then can optionally pass requests to a `tower::Service` instance.
// Behind that can be axum, tower, tonic,
// or any other service that implements the `tower::Service` trait.

use std::fs;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, Response, StatusCode};
use http_body::Body;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::{Bytes, Incoming};
use hyper::{body::Incoming, service::Service as HyperService};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use hyper_util::server::conn::auto::Builder as HttpConnBuilder;
use hyper_util::server::conn::auto::HttpServerConnExec;
use hyper_util::service::TowerToHyperService;
use pin_project::pin_project;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::ServerConfig;
use std::error::Error as StdError;
use std::future::pending;
use std::net::SocketAddr;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{fmt, fs, future::Future, io};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::macros::support::poll_fn;
use tokio::net::TcpListener;
use tokio::time::sleep;
use tokio_rustls::TlsAcceptor;
use tokio_stream::Stream;
use tokio_stream::StreamExt as _;
use tower::{Service, ServiceBuilder};
use tracing::{debug, trace};

// From `futures-util` crate, borrowed since this is the only dependency hyper-server requires.
// LICENSE: MIT or Apache-2.0
// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`.
#[pin_project]
struct Fuse<F> {
#[pin]
inner: Option<F>,
}

impl<F> Future for Fuse<F>
where
F: Future,
{
type Output = F::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.as_mut().project().inner.as_pin_mut() {
Some(fut) => fut.poll(cx).map(|output| {
self.project().inner.set(None);
output
}),
None => Poll::Pending,
}
}
}

type Source = Box<dyn StdError + Send + Sync + 'static>;

/// Errors that originate from the client hyper-server;
pub struct Error {
inner: ErrorImpl,
}

struct ErrorImpl {
kind: Kind,
source: Option<Source>,
}

#[derive(Debug)]
pub(crate) enum Kind {
Transport,
}

impl Error {
pub(crate) fn new(kind: Kind) -> Self {
Self {
inner: ErrorImpl { kind, source: None },
}
}

pub(crate) fn with(mut self, source: impl Into<Source>) -> Self {
self.inner.source = Some(source.into());
self
}

pub(crate) fn from_source(
source: impl Into<Error> + std::error::Error + std::marker::Send + std::marker::Sync + 'static,
) -> Self {
Error::new(Kind::Transport).with(source)
}

fn description(&self) -> &str {
match &self.inner.kind {
Kind::Transport => "transport error",
}
}
}

impl fmt::Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut f = f.debug_tuple("hyper_server::Error");

f.field(&self.inner.kind);

if let Some(source) = &self.inner.source {
f.field(source);
}

f.finish()
}
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.description())
}
}

impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
self.inner
.source
.as_ref()
.map(|source| &**source as &(dyn StdError + 'static))
}
}

async fn sleep_or_pending(wait_for: Option<Duration>) {
match wait_for {
Some(wait) => sleep(wait).await,
None => pending().await,
};
}

#[derive(Debug, Clone)]
pub struct Logger<S> {
Expand All @@ -43,6 +143,7 @@ impl<S> Logger<S> {
}
}
type Req = Request<Incoming>;

impl<S> Service<Req> for Logger<S>
where
S: Service<Req> + Clone,
Expand Down Expand Up @@ -114,9 +215,9 @@ async fn echo(req: Request<Incoming>) -> Result<Response<Full<Bytes>>, hyper::Er
Ok(response)
}

pub struct HyperServer {}
pub struct Server {}

impl HyperServer {
impl Server {
pub async fn serve(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Get a random port from the OS
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
Expand Down Expand Up @@ -173,7 +274,7 @@ impl HyperServer {
};

// Serve the http connection
if let Err(err) = Builder::new(TokioExecutor::new())
if let Err(err) = HttpConnBuilder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(tls_stream), service)
.await
{
Expand All @@ -182,6 +283,58 @@ impl HyperServer {
});
}
}

/// Serves a single HTTP connection from a hyper service backend
async fn serve_connection<B, IO, S, E>(
hyper_io: IO,
hyper_svc: S,
builder: HttpConnBuilder<E>,
mut watcher: Option<tokio::sync::watch::Receiver<()>>,
max_connection_age: Option<Duration>,
) where
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
S: HyperService<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
E: HttpServerConnExec<S::Future, B> + Send + Sync + 'static,
{
tokio::spawn(async move {
{
let mut sig = pin!(Fuse {
inner: watcher.as_mut().map(|w| w.changed()),
});

let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));

let sleep = sleep_or_pending(max_connection_age);
tokio::pin!(sleep);

loop {
tokio::select! {
rv = &mut conn => {
if let Err(err) = rv {
debug!("failed serving connection: {:#}", err);
}
break;
},
_ = &mut sleep => {
conn.as_mut().graceful_shutdown();
sleep.set(sleep_or_pending(None));
},
_ = &mut sig => {
conn.as_mut().graceful_shutdown();
}
}
}
}

drop(watcher);
trace!("connection closed");
});
}
}

#[cfg(test)]
Expand Down

0 comments on commit 2d27fb7

Please sign in to comment.