diff --git a/backend-rs/src/controllers.rs b/backend-rs/src/controllers.rs index bb13e5d..0db32ff 100644 --- a/backend-rs/src/controllers.rs +++ b/backend-rs/src/controllers.rs @@ -1,3 +1,2 @@ pub mod errors; pub mod user; -pub mod validate; diff --git a/backend-rs/src/controllers/errors.rs b/backend-rs/src/controllers/errors.rs index 8500b39..cd2ddbc 100644 --- a/backend-rs/src/controllers/errors.rs +++ b/backend-rs/src/controllers/errors.rs @@ -1,40 +1,46 @@ use crate::models::user; +use axum::extract::rejection::JsonRejection; +use axum::extract::{FromRequest, Request}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; +use axum::{async_trait, Json}; +use convert_case::{Case, Casing}; use lazy_static::lazy_static; use sea_orm::DbErr; +use serde::de::DeserializeOwned; use serde::Serialize; use serde_json::json; use std::collections::HashMap; use thiserror::Error; +use validator::{Validate, ValidationErrors, ValidationErrorsKind}; lazy_static! { - pub static ref FAILED_VALIDATION: Errors = + static ref FAILED_VALIDATION: Errors = Errors::generic("Some inputs failed validation.".to_string()); } #[derive(Serialize, Default, Clone)] -pub struct Errors { +struct Errors { generic: Vec, specific: HashMap>, } impl Errors { - pub fn generic(value: String) -> Self { + fn generic(value: String) -> Self { let mut errors = Self::default(); errors.add_generic(value); errors } - pub fn add_generic(&mut self, value: String) { + fn add_generic(&mut self, value: String) { self.generic.push(value); } - pub fn add_one_specific(&mut self, key: String, value: String) { + fn add_one_specific(&mut self, key: String, value: String) { self.add_specific(key, vec![value]); } - pub fn add_specific(&mut self, key: String, value: Vec) { + fn add_specific(&mut self, key: String, value: Vec) { self.specific.insert(key, value); } } @@ -47,6 +53,10 @@ impl IntoResponse for Errors { #[derive(Error, Debug)] pub enum HandlerError { + #[error(transparent)] + JsonRejection(#[from] JsonRejection), + #[error(transparent)] + Validation(#[from] ValidationErrors), #[error("Username must not be taken.")] UsernameTaken, #[error("Email Address must not be taken.")] @@ -77,9 +87,45 @@ impl HandlerError { } } +fn format_error_messages(field: &str, errors: ValidationErrorsKind) -> Vec { + let title = field.to_case(Case::Title); + + match errors { + ValidationErrorsKind::Field(errors) => errors + .into_iter() + .map(|e| match e.code.as_ref() { + "email" => format!("{} must be a valid email address.", title), + "must_match" => format!("{} must be identical to {}.", title, e.message.unwrap()), + "length" => format!( + "{} must be between {} and {} characters long (currently {}).", + title, + e.params.get("min").unwrap(), + e.params.get("max").unwrap(), + e.params.get("value").unwrap().as_str().unwrap().len(), + ), + code => unimplemented!( + "error message is not implemented for message code '{}'", + code + ), + }) + .collect(), + ValidationErrorsKind::Struct(_) | ValidationErrorsKind::List(_) => { + panic!("unexpected error type") + } + } +} + impl IntoResponse for HandlerError { fn into_response(self) -> Response { match self { + HandlerError::JsonRejection(_) => self.into_generic(StatusCode::BAD_REQUEST), + HandlerError::Validation(validation_errors) => (StatusCode::BAD_REQUEST, { + let mut errors = FAILED_VALIDATION.clone(); + for (k, v) in validation_errors.into_errors() { + errors.add_specific(k.to_case(Case::Camel), format_error_messages(k, v)); + } + errors + }), HandlerError::UsernameTaken => self.failed_validation(StatusCode::BAD_REQUEST, "username"), HandlerError::EmailTaken => self.failed_validation(StatusCode::BAD_REQUEST, "emailAddress"), HandlerError::DecodeJwt(_) => self.into_generic(StatusCode::BAD_REQUEST), @@ -93,3 +139,22 @@ impl IntoResponse for HandlerError { .into_response() } } + +#[derive(Debug, Clone, Copy, Default)] +pub struct ValidatedJson(pub T); + +#[async_trait] +impl FromRequest for ValidatedJson +where + T: DeserializeOwned + Validate, + S: Send + Sync, + Json: FromRequest, +{ + type Rejection = HandlerError; + + async fn from_request(req: Request, state: &S) -> Result { + let Json(value) = Json::from_request(req, state).await?; + value.validate()?; + Ok(ValidatedJson(value)) + } +} diff --git a/backend-rs/src/controllers/user.rs b/backend-rs/src/controllers/user.rs index 0ea48c7..4122df3 100644 --- a/backend-rs/src/controllers/user.rs +++ b/backend-rs/src/controllers/user.rs @@ -1,5 +1,4 @@ -use crate::controllers::errors::HandlerError; -use crate::controllers::validate::ValidatedJson; +use crate::controllers::errors::{HandlerError, ValidatedJson}; use crate::models::id::Id; use crate::models::user::{self, CreateUser, User}; use crate::AppState; diff --git a/backend-rs/src/controllers/validate.rs b/backend-rs/src/controllers/validate.rs deleted file mode 100644 index 68794c4..0000000 --- a/backend-rs/src/controllers/validate.rs +++ /dev/null @@ -1,84 +0,0 @@ -use crate::controllers::errors::{self, Errors}; -use axum::extract::rejection::{FormRejection, JsonRejection}; -use axum::extract::{FromRequest, Request}; -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::{async_trait, Form, Json}; -use convert_case::{Case, Casing}; -use serde::de::DeserializeOwned; -use thiserror::Error; -use validator::{Validate, ValidationErrors, ValidationErrorsKind}; - -#[derive(Debug, Error)] -pub enum ValidationError { - #[error(transparent)] - JsonRejection(#[from] JsonRejection), - #[error(transparent)] - ValidationError(#[from] ValidationErrors), -} - -fn format_messages(field: &str, errors: ValidationErrorsKind) -> Vec { - let title = field.to_case(Case::Title); - - match errors { - ValidationErrorsKind::Field(errors) => errors - .into_iter() - .map(|e| match e.code.as_ref() { - "email" => format!("{} must be a valid email address.", title), - "must_match" => format!("{} must be identical to {}.", title, e.message.unwrap()), - "length" => format!( - "{} must be between {} and {} characters long (currently {}).", - title, - e.params.get("min").unwrap(), - e.params.get("max").unwrap(), - e.params.get("value").unwrap().as_str().unwrap().len(), - ), - code => unimplemented!( - "error message is not implemented for message code '{}'", - code - ), - }) - .collect(), - ValidationErrorsKind::Struct(_) | ValidationErrorsKind::List(_) => { - panic!("unexpected error type") - } - } -} - -impl IntoResponse for ValidationError { - fn into_response(self) -> Response { - ( - StatusCode::BAD_REQUEST, - match self { - ValidationError::JsonRejection(_) => Errors::generic(self.to_string()), - ValidationError::ValidationError(validation_errors) => { - let mut errors = errors::FAILED_VALIDATION.clone(); - for (k, v) in validation_errors.into_errors() { - errors.add_specific(k.to_case(Case::Camel), format_messages(k, v)); - } - errors - } - }, - ) - .into_response() - } -} - -#[derive(Debug, Clone, Copy, Default)] -pub struct ValidatedJson(pub T); - -#[async_trait] -impl FromRequest for ValidatedJson -where - T: DeserializeOwned + Validate, - S: Send + Sync, - Form: FromRequest, -{ - type Rejection = ValidationError; - - async fn from_request(req: Request, state: &S) -> Result { - let Json(value) = Json::::from_request(req, state).await?; - value.validate()?; - Ok(ValidatedJson(value)) - } -}