Skip to content

Commit

Permalink
Update to rustls 0.22 (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Jan 21, 2024
1 parent 04b33e7 commit ad7d00d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 75 deletions.
11 changes: 6 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ log = "0.4.17"
mime = {version = "0.3.16", optional = true}
multipart = {version = "0.18.0", default-features = false, features = ["client"], optional = true}
native-tls = {version = "0.2.10", optional = true}
rustls-native-certs = { version = "0.6", optional = true}
rustls-opt-dep = {package = "rustls", version = "0.21.0", features = ["dangerous_configuration"], optional = true}
rustls-native-certs = { version = "0.7", optional = true}
rustls-opt-dep = {package = "rustls", version = "0.22.0", optional = true}
rustls-pki-types = "1"
serde = {version = "1.0.143", optional = true}
serde_json = {version = "1.0.83", optional = true}
serde_urlencoded = {version = "0.7.1", optional = true}
url = "2.2.2"
webpki-roots = {version = "0.25.1", optional = true}
webpki-roots = {version = "0.26.0", optional = true}

[dev-dependencies]
anyhow = "1.0.61"
Expand All @@ -40,9 +41,9 @@ http02 = {package = "http", version = "0.2"}
hyper = "0.14.20"
lazy_static = "1.4.0"
multipart = {version = "0.18.0", default-features = false, features = ["server"]}
rustls-pemfile = "1.0.3"
rustls-pemfile = "2"
tokio = {version = "1.20.1", features = ["full"]}
tokio-rustls = "0.24.1"
tokio-rustls = "0.25.0"
tokio-stream = {version = "0.1.9", features = ["net"]}
warp = "0.3.2"

Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ pub enum ErrorKind {
InvalidMimeType(String),
/// TLS was not enabled by features.
TlsDisabled,
/// Empty cert store
#[cfg(all(feature = "__rustls", not(feature = "tls-native")))]
ServerCertVerifier(rustls::client::VerifierBuilderError),
}

/// A type that contains all the errors that can possibly occur while accessing an HTTP server.
Expand Down Expand Up @@ -132,6 +135,8 @@ impl Display for Error {
InvalidDNSName(ref e) => write!(w, "Invalid DNS name: {e}"),
InvalidMimeType(ref e) => write!(w, "Invalid mime type: {e}"),
TlsDisabled => write!(w, "TLS is disabled, activate one of the tls- features"),
#[cfg(all(feature = "__rustls", not(feature = "tls-native")))]
ServerCertVerifier(ref e) => write!(w, "Invalid certificate: {e}"),
}
}
}
Expand Down Expand Up @@ -216,6 +221,13 @@ impl From<InvalidResponseKind> for Error {
}
}

#[cfg(all(feature = "__rustls", not(feature = "tls-native")))]
impl From<rustls::client::VerifierBuilderError> for Error {
fn from(err: rustls::client::VerifierBuilderError) -> Error {
Error(Box::new(ErrorKind::ServerCertVerifier(err)))
}
}

impl From<Error> for io::Error {
fn from(err: Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
Expand Down
101 changes: 68 additions & 33 deletions src/tls/rustls_impl.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
use std::convert::TryInto;
use std::convert::TryFrom;
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::sync::Arc;
use std::time::SystemTime;

#[cfg(feature = "tls-rustls-webpki-roots")]
use rustls::OwnedTrustAnchor;
use rustls::{
client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier},
ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned,
client::{
danger::{DangerousClientConfigBuilder, HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
WebPkiServerVerifier,
},
ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme, StreamOwned,
};
#[cfg(feature = "tls-rustls-native-roots")]
use rustls_native_certs::load_native_certs;
use rustls_pki_types::{CertificateDer, ServerName, UnixTime};
#[cfg(feature = "tls-rustls-webpki-roots")]
use webpki_roots::TLS_SERVER_ROOTS;

use crate::{Error, ErrorKind, Result};

pub type Certificate = rustls::Certificate;
pub type Certificate = CertificateDer<'static>;

pub struct TlsHandshaker {
inner: Option<Arc<ClientConfig>>,
Expand Down Expand Up @@ -59,36 +60,33 @@ impl TlsHandshaker {
let mut root_store = RootCertStore::empty();

#[cfg(feature = "tls-rustls-webpki-roots")]
root_store.add_server_trust_anchors(TLS_SERVER_ROOTS.iter().map(|root| {
OwnedTrustAnchor::from_subject_spki_name_constraints(root.subject, root.spki, root.name_constraints)
}));
root_store.extend(TLS_SERVER_ROOTS.iter().cloned());

#[cfg(feature = "tls-rustls-native-roots")]
for native_cert in load_native_certs()? {
let cert = rustls::Certificate(native_cert.0);
for cert in load_native_certs()? {
// Inspired by https://github.com/seanmonstar/reqwest/blob/231b18f83572836c674404b33cb1ca8b35ca3e36/src/async_impl/client.rs#L363-L365
// Native certificate stores often include certificates with invalid formats,
// but we don't want those invalid entries to invalidate the entire process of
// loading native root certificates
if let Err(e) = root_store.add(&cert) {
if let Err(e) = root_store.add(cert) {
warn!("Could not load native root certificate: {}", e);
}
}

for cert in &self.additional_certs {
for cert in self.additional_certs.iter().cloned() {
root_store.add(cert)?;
}

let config = Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(CustomCertVerifier {
upstream: WebPkiVerifier::new(root_store, None),
accept_invalid_certs: self.accept_invalid_certs,
accept_invalid_hostnames: self.accept_invalid_hostnames,
}))
.with_no_client_auth(),
);
let config = DangerousClientConfigBuilder {
cfg: ClientConfig::builder(),
}
.with_custom_certificate_verifier(Arc::new(CustomCertVerifier {
upstream: WebPkiServerVerifier::builder(root_store.into()).build()?,
accept_invalid_certs: self.accept_invalid_certs,
accept_invalid_hostnames: self.accept_invalid_hostnames,
}))
.with_no_client_auth()
.into();

self.inner = Some(Arc::clone(&config));

Expand All @@ -101,9 +99,9 @@ impl TlsHandshaker {
where
S: Read + Write,
{
let domain = domain
.try_into()
.map_err(|_| Error(Box::new(ErrorKind::InvalidDNSName(domain.to_owned()))))?;
let domain = ServerName::try_from(domain)
.map_err(|_| Error(Box::new(ErrorKind::InvalidDNSName(domain.to_owned()))))?
.to_owned();
let config = self.client_config()?;
let mut session = ClientConnection::new(config, domain)?;

Expand Down Expand Up @@ -184,24 +182,29 @@ where
}

struct CustomCertVerifier {
upstream: WebPkiVerifier,
upstream: Arc<WebPkiServerVerifier>,
accept_invalid_certs: bool,
accept_invalid_hostnames: bool,
}

impl fmt::Debug for CustomCertVerifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("CustomCertVerifier").finish()
}
}

impl ServerCertVerifier for CustomCertVerifier {
fn verify_server_cert(
&self,
end_entity: &Certificate,
intermediates: &[Certificate],
end_entity: &CertificateDer,
intermediates: &[CertificateDer],
server_name: &ServerName,
scts: &mut dyn Iterator<Item = &[u8]>,
ocsp_response: &[u8],
now: SystemTime,
now: UnixTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
match self
.upstream
.verify_server_cert(end_entity, intermediates, server_name, scts, ocsp_response, now)
.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
{
Err(rustls::Error::NoCertificatesPresented | rustls::Error::InvalidCertificate(_))
if self.accept_invalid_certs =>
Expand All @@ -218,4 +221,36 @@ impl ServerCertVerifier for CustomCertVerifier {
upstream => upstream,
}
}

fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}

fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.upstream.supported_verify_schemes()
}
}
53 changes: 16 additions & 37 deletions tests/tools/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ use futures::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::{ServerConfig, Error as TLSError, PrivateKey, Certificate};
use tokio_rustls::rustls::{Error as TLSError, ServerConfig};

/// Represents errors that can occur building the TlsConfig
#[derive(Debug)]
pub(crate) enum TlsConfigError {
Io(io::Error),
/// An Error parsing the Certificate
CertParseError,
/// An Error parsing a Pkcs8 key
Pkcs8ParseError,
/// An Error parsing a Rsa key
RsaParseError,
/// An error from an empty key
EmptyKey,
/// An error from an invalid key
Expand All @@ -36,8 +32,6 @@ impl std::fmt::Display for TlsConfigError {
match self {
TlsConfigError::Io(err) => err.fmt(f),
TlsConfigError::CertParseError => write!(f, "certificate parse error"),
TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"),
TlsConfigError::RsaParseError => write!(f, "rsa parse error"),
TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {err}"),
}
Expand All @@ -46,6 +40,12 @@ impl std::fmt::Display for TlsConfigError {

impl std::error::Error for TlsConfigError {}

impl From<io::Error> for TlsConfigError {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}

/// Builder to set the configuration for the Tls server.
pub(crate) struct TlsConfigBuilder {
cert: Box<dyn Read + Send + Sync>,
Expand Down Expand Up @@ -81,41 +81,20 @@ impl TlsConfigBuilder {
self
}

pub(crate) fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
pub(crate) fn build(self) -> Result<ServerConfig, TlsConfigError> {
let mut cert_rdr = BufReader::new(self.cert);
let cert = rustls_pemfile::certs(&mut cert_rdr)
.map_err(|_| TlsConfigError::CertParseError)?.into_iter().map(Certificate).collect();

let key = {
// convert it to Vec<u8> to allow reading it again if key is RSA
let mut key_vec = Vec::new();
self.key.read_to_end(&mut key_vec).map_err(TlsConfigError::Io)?;

if key_vec.is_empty() {
return Err(TlsConfigError::EmptyKey);
}

let mut pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut key_vec.as_slice())
.map_err(|_| TlsConfigError::Pkcs8ParseError)?;

if !pkcs8.is_empty() {
PrivateKey(pkcs8.remove(0))
} else {
let mut rsa = rustls_pemfile::rsa_private_keys(&mut key_vec.as_slice())
.map_err(|_| TlsConfigError::RsaParseError)?;

if !rsa.is_empty() {
PrivateKey(rsa.remove(0))
} else {
return Err(TlsConfigError::EmptyKey);
}
}
.collect::<Result<Vec<_>, _>>()
.map_err(|_| TlsConfigError::CertParseError)?;

let Some(key) = rustls_pemfile::private_key(&mut BufReader::new(self.key))? else {
return Err(TlsConfigError::EmptyKey);
};

let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth().
with_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new()).map_err(TlsConfigError::InvalidKey)?;
.with_no_client_auth()
.with_single_cert_with_ocsp(cert, key, self.ocsp_resp)
.map_err(TlsConfigError::InvalidKey)?;
Ok(config)
}
}
Expand Down

0 comments on commit ad7d00d

Please sign in to comment.