From f3a6a4607eeaec7bc95bbd7027943523f258f69d Mon Sep 17 00:00:00 2001 From: Aleks Todorov Date: Sat, 10 Feb 2024 15:10:32 +0000 Subject: [PATCH] Implement user fetching After a user is authenticated, allow fetching the current user at /api/user. All requests can receive the user authenticated for the request by accepting a user parameter. Also refactored the Authenticate error into EncodeJwt (to partner with the new DecodeJwt error). In addition, implemented Serialize / Deserialize for Id so that it doesn't have to be manually converted into a string anywhere. --- backend-rs/src/controllers/errors.rs | 16 ++++-- backend-rs/src/controllers/user.rs | 77 +++++++++++++++++++++++----- backend-rs/src/main.rs | 6 ++- backend-rs/src/models/id.rs | 65 ++++++++++++++++++++--- backend-rs/src/models/user.rs | 10 ++++ 5 files changed, 149 insertions(+), 25 deletions(-) diff --git a/backend-rs/src/controllers/errors.rs b/backend-rs/src/controllers/errors.rs index a2d6277..8500b39 100644 --- a/backend-rs/src/controllers/errors.rs +++ b/backend-rs/src/controllers/errors.rs @@ -1,4 +1,3 @@ -use crate::controllers::user::AuthError; use crate::models::user; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; @@ -52,13 +51,18 @@ pub enum HandlerError { UsernameTaken, #[error("Email Address must not be taken.")] EmailTaken, + #[error("Authentication token is invalid.")] + DecodeJwt(jsonwebtoken::errors::Error), + + #[error("No user was logged in but a user is required.")] + UserRequired, #[error("Database transaction failed.")] Database(#[from] DbErr), #[error("Could not create user.")] CreateUser(#[from] user::CreateError), - #[error("Could not authenticate: {0}.")] - Authenticate(#[from] AuthError), + #[error("Could not encode JWT.")] + EncodeJwt(jsonwebtoken::errors::Error), } impl HandlerError { @@ -78,9 +82,13 @@ impl IntoResponse for HandlerError { match self { HandlerError::UsernameTaken => self.failed_validation(StatusCode::BAD_REQUEST, "username"), HandlerError::EmailTaken => self.failed_validation(StatusCode::BAD_REQUEST, "emailAddress"), + HandlerError::DecodeJwt(_) => self.into_generic(StatusCode::BAD_REQUEST), + + HandlerError::UserRequired => self.into_generic(StatusCode::UNAUTHORIZED), + HandlerError::Database(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR), HandlerError::CreateUser(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR), - HandlerError::Authenticate(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR), + HandlerError::EncodeJwt(_) => self.into_generic(StatusCode::INTERNAL_SERVER_ERROR), } .into_response() } diff --git a/backend-rs/src/controllers/user.rs b/backend-rs/src/controllers/user.rs index 10e2298..c6cdb23 100644 --- a/backend-rs/src/controllers/user.rs +++ b/backend-rs/src/controllers/user.rs @@ -1,14 +1,17 @@ use crate::app_state::AppState; use crate::controllers::errors::HandlerError; use crate::controllers::validate::ValidatedJson; +use crate::models::id::Id; use crate::models::user::{self, CreateUser, User}; -use axum::extract::State; +use axum::extract::{FromRequestParts, State}; +use axum::http::request::Parts; use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::{async_trait, Json, RequestPartsExt}; use axum_extra::extract::cookie::{Cookie, CookieJar}; use chrono::{Duration, Utc}; -use jsonwebtoken::{EncodingKey, Header}; +use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; -use thiserror::Error; use validator::Validate; const AUTH_COOKIE_KEY: &str = "Authorization"; @@ -16,27 +19,26 @@ const SECRET: &str = "123456"; #[derive(Debug, Serialize, Deserialize)] struct Claims { - uid: String, + uid: Id, exp: usize, } -#[derive(Error, Debug)] -pub enum AuthError { - #[error("could not encode JWT")] - EncodeJwt(#[from] jsonwebtoken::errors::Error), -} - -fn authenticate(jar: CookieJar, user: &User, lifespan: Duration) -> Result { +fn authenticate( + jar: CookieJar, + user: &User, + lifespan: Duration, +) -> Result { let cookie = Cookie::build(( AUTH_COOKIE_KEY, jsonwebtoken::encode( &Header::default(), &Claims { - uid: user.id.to_string(), + uid: user.id, exp: (Utc::now() + lifespan).timestamp().try_into().unwrap(), }, &EncodingKey::from_secret(SECRET.as_ref()), - )?, + ) + .map_err(HandlerError::EncodeJwt)?, )) .max_age(time::Duration::seconds(lifespan.num_seconds())) .path("/") @@ -84,3 +86,52 @@ pub async fn register( Ok((StatusCode::CREATED, cookies)) } + +#[async_trait] +impl FromRequestParts for User { + type Rejection = HandlerError; + + async fn from_request_parts( + parts: &mut Parts, + state: &AppState, + ) -> Result { + let cookies = parts.extract::().await.unwrap(); + let token = cookies + .get(AUTH_COOKIE_KEY) + .ok_or(HandlerError::UserRequired)? + .value(); + let claims = jsonwebtoken::decode::( + token, + &DecodingKey::from_secret(SECRET.as_ref()), + &Validation::default(), + ) + .map_err(HandlerError::DecodeJwt)? + .claims; + + user::find(state, claims.uid) + .await? + .ok_or(HandlerError::UserRequired) + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct UserResponse { + id: Id, + username: String, +} + +pub async fn get_user(user: Option) -> Response { + if let Some(u) = user { + ( + StatusCode::OK, + Json(UserResponse { + id: u.id, + username: u.visible_username().to_string(), + }), + ) + .into_response() + } else { + StatusCode::NOT_FOUND.into_response() + } +} diff --git a/backend-rs/src/main.rs b/backend-rs/src/main.rs index 8a16d27..b342cd2 100644 --- a/backend-rs/src/main.rs +++ b/backend-rs/src/main.rs @@ -52,7 +52,8 @@ async fn main() -> Result<(), AppError> { .map_err(AppError::MigrateDatabase)?; let app = Router::new() - .route("/api/auth/register", routing::post(user::register)) + .route("/auth/register", routing::post(user::register)) + .route("/user", routing::get(user::get_user)) .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) @@ -63,6 +64,7 @@ async fn main() -> Result<(), AppError> { connection, user_snowflake: SnowflakeIdBucket::new(1, 1), }); + let api = Router::new().nest("/api", app); let listener = TcpListener::bind(SocketAddr::from((config.app.host, config.app.port))) .await @@ -75,7 +77,7 @@ async fn main() -> Result<(), AppError> { .map_err(AppError::GetListenerAddress)? ); - axum::serve(listener, app) + axum::serve(listener, api) .await .map_err(AppError::ServeApp)?; diff --git a/backend-rs/src/models/id.rs b/backend-rs/src/models/id.rs index b6e1a86..4a00aa1 100644 --- a/backend-rs/src/models/id.rs +++ b/backend-rs/src/models/id.rs @@ -1,6 +1,9 @@ use sea_orm::{DbErr, DeriveValueType, TryFromU64}; +use serde::de::{Unexpected, Visitor}; +use serde::{Deserialize, Serialize}; +use std::fmt::Formatter; -#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, DeriveValueType)] pub struct Id(i64); impl Id { @@ -21,12 +24,62 @@ impl TryFromU64 for Id { } } +fn encode(id: Id) -> Result { + let mut encoded = String::new(); + bs58::encode(id.0.to_be_bytes()).onto(&mut encoded)?; + + Ok(encoded) +} + +fn decode(str: &str) -> Result { + let mut decoded = [0; 8]; + bs58::decode(str).onto(&mut decoded)?; + + Ok(Id::from(i64::from_be_bytes(decoded))) +} + impl ToString for Id { fn to_string(&self) -> String { - let mut string = String::new(); - bs58::encode(self.0.to_be_bytes()) - .onto(&mut string) - .expect("could not encode ID into base 58"); - string + encode(*self).expect("could not encode ID into base 58") + } +} + +impl Serialize for Id { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let string = encode(*self).map_err(|err| { + serde::ser::Error::custom(format!("could not encode ID into base58: {}", err)) + })?; + serializer.serialize_str(&string) + } +} + +impl<'de> Deserialize<'de> for Id { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(IdVisitor) + } +} + +struct IdVisitor; + +impl<'de> Visitor<'de> for IdVisitor { + type Value = Id; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("a base58-encoded string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + decode(v).map_err(|_| { + serde::de::Error::invalid_value(Unexpected::Str(v), &"a valid base-58 encoded string") + }) } } diff --git a/backend-rs/src/models/user.rs b/backend-rs/src/models/user.rs index 5b7f6eb..95f7280 100644 --- a/backend-rs/src/models/user.rs +++ b/backend-rs/src/models/user.rs @@ -26,6 +26,12 @@ pub struct Model { pub password: Option, } +impl Model { + pub fn visible_username(&self) -> &str { + self.username.as_ref().map_or("Anonymous", String::as_str) + } +} + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation {} @@ -83,3 +89,7 @@ pub async fn create(state: &mut AppState, user: CreateUser<'_>) -> Result Result, DbErr> { + Entity::find_by_id(id).one(&state.connection).await +}