Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: use rsa instead of openssl #116

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ deprecated = "deny"
[lints.clippy]
perf = { level = "deny", priority = -1 }
complexity = { level = "deny", priority = -1 }
dbg_macro = "deny"
dbg_macro = "deny"
inefficient_to_string = "deny"
items-after-statements = "deny"
items-after-statements = "deny"
implicit_clone = "deny"
wildcard_imports = "deny"
cast_lossless = "deny"
Expand All @@ -41,10 +41,11 @@ reqwest = { version = "0.11.27", features = ["json", "stream"] }
reqwest-middleware = "0.2.5"
tracing = "0.1.40"
base64 = "0.22.1"
openssl = "0.10.64"
rand = "0.8.5"
rsa = "0.9.6"
once_cell = "1.19.0"
http = "0.2.12"
sha2 = "0.10.8"
sha2 = { version = "0.10.8", features = ["oid"] }
kwaa marked this conversation as resolved.
Show resolved Hide resolved
thiserror = "1.0.59"
derive_builder = "0.20.0"
itertools = "0.12.1"
Expand All @@ -61,14 +62,19 @@ bytes = "1.6.0"
futures-core = { version = "0.3.30", default-features = false }
pin-project-lite = "0.2.14"
activitystreams-kinds = "0.3.0"
regex = { version = "1.10.5", default-features = false, features = ["std", "unicode"] }
regex = { version = "1.10.5", default-features = false, features = [
"std",
"unicode",
] }
tokio = { version = "1.37.0", features = [
"sync",
"rt",
"rt-multi-thread",
"time",
] }
diesel = { version = "2.1.6", features = ["postgres"], default-features = false, optional = true }
diesel = { version = "2.1.6", features = [
"postgres",
], default-features = false, optional = true }
futures = "0.3.30"
moka = { version = "0.12.7", features = ["future"] }

Expand All @@ -82,11 +88,10 @@ axum = { version = "0.6.20", features = [
], default-features = false, optional = true }
tower = { version = "0.4.13", optional = true }
hyper = { version = "0.14", optional = true }
http-body-util = {version = "0.1.1", optional = true }
http-body-util = { version = "0.1.1", optional = true }

[dev-dependencies]
anyhow = "1.0.82"
rand = "0.8.5"
env_logger = "0.11.3"
tower-http = { version = "0.5.2", features = ["map-request-body", "util"] }
axum = { version = "0.6.20", features = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might as well add the toml formatting to CI. You can copy it from the Lemmy repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might as well add the toml formatting to CI. You can copy it from the Lemmy repo.

I'll take a look, but maybe changing CI as another PR would be better.

Expand Down
10 changes: 5 additions & 5 deletions src/activity_sending.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::StreamExt;
use http::StatusCode;
use httpdate::fmt_http_date;
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Response,
};
use reqwest_middleware::ClientWithMiddleware;
use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey};
use serde::Serialize;
use std::{
fmt::{Debug, Display},
Expand All @@ -37,7 +37,7 @@ pub struct SendActivityTask {
pub(crate) activity_id: Url,
pub(crate) activity: Bytes,
pub(crate) inbox: Url,
pub(crate) private_key: PKey<Private>,
pub(crate) private_key: RsaPrivateKey,
pub(crate) http_signature_compat: bool,
}

Expand Down Expand Up @@ -172,7 +172,7 @@ where
pub(crate) async fn get_pkey_cached<ActorType>(
data: &Data<impl Clone>,
actor: &ActorType,
) -> Result<PKey<Private>, Error>
) -> Result<RsaPrivateKey, Error>
where
ActorType: Actor,
{
Expand All @@ -189,13 +189,13 @@ where

// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
let pkey = tokio::task::spawn_blocking(move || {
PKey::private_key_from_pem(private_key_pem.as_bytes()).map_err(|err| {
RsaPrivateKey::from_pkcs8_pem(&private_key_pem).map_err(|err| {
Error::Other(format!("Could not create private key from PEM data:{err}"))
})
})
.await
.map_err(|err| Error::Other(format!("Error joining: {err}")))??;
std::result::Result::<PKey<Private>, Error>::Ok(pkey)
std::result::Result::<RsaPrivateKey, Error>::Ok(pkey)
})
.await
.map_err(|e| Error::Other(format!("cloned error: {e}")))
Expand Down
10 changes: 5 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use async_trait::async_trait;
use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone};
use moka::future::Cache;
use openssl::pkey::{PKey, Private};
use reqwest_middleware::ClientWithMiddleware;
use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey};
use serde::de::DeserializeOwned;
use std::{
ops::Deref,
Expand Down Expand Up @@ -80,12 +80,12 @@ pub struct FederationConfig<T: Clone> {
/// This can be used to implement secure mode federation.
/// <https://docs.joinmastodon.org/spec/activitypub/#secure-mode>
#[builder(default = "None", setter(custom))]
pub(crate) signed_fetch_actor: Option<Arc<(Url, PKey<Private>)>>,
pub(crate) signed_fetch_actor: Option<Arc<(Url, RsaPrivateKey)>>,
#[builder(
default = "Cache::builder().max_capacity(10000).build()",
setter(custom)
)]
pub(crate) actor_pkey_cache: Cache<Url, PKey<Private>>,
pub(crate) actor_pkey_cache: Cache<Url, RsaPrivateKey>,
/// Queue for sending outgoing activities. Only optional to make builder work, its always
/// present once constructed.
#[builder(setter(skip))]
Expand Down Expand Up @@ -200,8 +200,8 @@ impl<T: Clone> FederationConfigBuilder<T> {
.private_key_pem()
.expect("actor does not have a private key to sign with");

let private_key = PKey::private_key_from_pem(private_key_pem.as_bytes())
.expect("Could not decode PEM data");
let private_key =
RsaPrivateKey::from_pkcs8_pem(&private_key_pem).expect("Could not decode PEM data");
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key))));
self
}
Expand Down
21 changes: 18 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

use crate::fetch::webfinger::WebFingerError;
use http_signature_normalization_reqwest::SignError;
use openssl::error::ErrorStack;
use rsa::{
errors::Error as RsaError,
pkcs8::{spki::Error as SpkiError, Error as Pkcs8Error},
};
use std::string::FromUtf8Error;
use tokio::task::JoinError;
use url::Url;
Expand Down Expand Up @@ -80,8 +83,20 @@ pub enum Error {
Other(String),
}

impl From<ErrorStack> for Error {
fn from(value: ErrorStack) -> Self {
impl From<RsaError> for Error {
fn from(value: RsaError) -> Self {
Error::Other(value.to_string())
}
}

impl From<Pkcs8Error> for Error {
fn from(value: Pkcs8Error) -> Self {
Error::Other(value.to_string())
}
}

impl From<SpkiError> for Error {
fn from(value: SpkiError) -> Self {
Error::Other(value.to_string())
}
}
Expand Down
81 changes: 41 additions & 40 deletions src/http_signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ use http_signature_normalization_reqwest::{
DefaultSpawner,
};
use once_cell::sync::Lazy;
use openssl::{
hash::MessageDigest,
pkey::{PKey, Private},
rsa::Rsa,
sign::{Signer, Verifier},
};
use reqwest::Request;
use reqwest_middleware::RequestBuilder;
use rsa::{
pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey, LineEnding},
Pkcs1v15Sign,
RsaPrivateKey,
RsaPublicKey,
};
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind, time::Duration};
use std::{collections::BTreeMap, fmt::Debug, time::Duration};
use tracing::debug;
use url::Url;

Expand All @@ -46,27 +46,23 @@ pub struct Keypair {
impl Keypair {
/// Helper method to turn this into an openssl private key
#[cfg(test)]
pub(crate) fn private_key(&self) -> Result<PKey<Private>, anyhow::Error> {
Ok(PKey::private_key_from_pem(self.private_key.as_bytes())?)
pub(crate) fn private_key(&self) -> Result<RsaPrivateKey, anyhow::Error> {
use rsa::pkcs8::DecodePrivateKey;

Ok(RsaPrivateKey::from_pkcs8_pem(&self.private_key)?)
}
}

/// Generate a random asymmetric keypair for ActivityPub HTTP signatures.
pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
let rsa = Rsa::generate(2048)?;
let pkey = PKey::from_rsa(rsa)?;
let public_key = pkey.public_key_to_pem()?;
let private_key = pkey.private_key_to_pem_pkcs8()?;
let key_to_string = |key| match String::from_utf8(key) {
Ok(s) => Ok(s),
Err(e) => Err(std::io::Error::new(
ErrorKind::Other,
format!("Failed converting key to string: {}", e),
)),
};
pub fn generate_actor_keypair() -> Result<Keypair, Error> {
let mut rng = rand::thread_rng();
let rsa = RsaPrivateKey::new(&mut rng, 2048)?;
let pkey = RsaPublicKey::from(&rsa);
let public_key = pkey.to_public_key_pem(LineEnding::default())?;
let private_key = rsa.to_pkcs8_pem(LineEnding::default())?.to_string();
Ok(Keypair {
private_key: key_to_string(private_key)?,
public_key: key_to_string(public_key)?,
private_key,
public_key,
})
}

Expand All @@ -83,7 +79,7 @@ pub(crate) async fn sign_request(
request_builder: RequestBuilder,
actor_id: &Url,
activity: Bytes,
private_key: PKey<Private>,
private_key: RsaPrivateKey,
http_signature_compat: bool,
) -> Result<Request, Error> {
static CONFIG: Lazy<Config<DefaultSpawner>> =
Expand All @@ -106,10 +102,10 @@ pub(crate) async fn sign_request(
Sha256::new(),
activity,
move |signing_string| {
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
signer.update(signing_string.as_bytes())?;

Ok(Base64.encode(signer.sign_to_vec()?)) as Result<_, Error>
Ok(Base64.encode(private_key.sign(
Pkcs1v15Sign::new::<Sha256>(),
&Sha256::digest(signing_string.as_bytes()),
)?)) as Result<_, Error>
},
)
.await
Expand Down Expand Up @@ -205,15 +201,19 @@ fn verify_signature_inner(
"Verifying with key {}, message {}",
&public_key, &signing_string
);
let public_key = PKey::public_key_from_pem(public_key.as_bytes())?;
let mut verifier = Verifier::new(MessageDigest::sha256(), &public_key)?;
verifier.update(signing_string.as_bytes())?;
let public_key = RsaPublicKey::from_public_key_pem(public_key)?;

let base64_decoded = Base64
.decode(signature)
.map_err(|err| Error::Other(err.to_string()))?;

Ok(verifier.verify(&base64_decoded)?)
Ok(public_key
.verify(
Pkcs1v15Sign::new::<Sha256>(),
&Sha256::digest(signing_string.as_bytes()),
&base64_decoded,
)
.is_ok())
})?;

if verified {
Expand Down Expand Up @@ -284,6 +284,7 @@ pub mod test {
use crate::activity_sending::generate_request_headers;
use reqwest::Client;
use reqwest_middleware::ClientWithMiddleware;
use rsa::{pkcs1::DecodeRsaPrivateKey, pkcs8::DecodePrivateKey};
use std::str::FromStr;

static ACTOR_ID: Lazy<Url> = Lazy::new(|| Url::parse("https://example.com/u/alice").unwrap());
Expand All @@ -306,7 +307,7 @@ pub mod test {
request_builder,
&ACTOR_ID,
"my activity".into(),
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
RsaPrivateKey::from_pkcs8_pem(&test_keypair().private_key).unwrap(),
// set this to prevent created/expires headers to be generated and inserted
// automatically from current time
true,
Expand Down Expand Up @@ -342,7 +343,7 @@ pub mod test {
request_builder,
&ACTOR_ID,
"my activity".to_string().into(),
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
RsaPrivateKey::from_pkcs8_pem(&test_keypair().private_key).unwrap(),
false,
)
.await
Expand Down Expand Up @@ -378,13 +379,13 @@ pub mod test {
}

pub fn test_keypair() -> Keypair {
let rsa = Rsa::private_key_from_pem(PRIVATE_KEY.as_bytes()).unwrap();
let pkey = PKey::from_rsa(rsa).unwrap();
let private_key = pkey.private_key_to_pem_pkcs8().unwrap();
let public_key = pkey.public_key_to_pem().unwrap();
let rsa = RsaPrivateKey::from_pkcs1_pem(PRIVATE_KEY).unwrap();
let pkey = RsaPublicKey::from(&rsa);
let public_key = pkey.to_public_key_pem(LineEnding::default()).unwrap();
let private_key = rsa.to_pkcs8_pem(LineEnding::default()).unwrap().to_string();
Keypair {
private_key: String::from_utf8(private_key).unwrap(),
public_key: String::from_utf8(public_key).unwrap(),
private_key,
public_key,
}
}

Expand Down