From ea7b66624befc9f923aaacfef862ad1023929448 Mon Sep 17 00:00:00 2001 From: pintariching <64165058+pintariching@users.noreply.github.com> Date: Fri, 14 Apr 2023 12:50:00 +0200 Subject: [PATCH] Trait validation (#225) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implemented validation trait for length * converted identation to spaces * changed the trait to not require HasLen * added macro for generating impls * implemented ValidateLength for some types * using trait validation instead of the function * added cfg for indexmap import * changed trait to require length * Revert "changed trait to require length" This reverts commit a77bdc9297a65f9eb3dfa345c92e7cb1aee2525f. * moved validation logic inside ValidateLength trait * added trait validation for required * added email trait validation * fixed trait validation for email * added range trait validation * fixed range trait * added url trait validation --------- Co-authored-by: Tilen Pintarič --- validator/src/lib.rs | 10 +- validator/src/validation/email.rs | 97 +++++++---- validator/src/validation/length.rs | 202 ++++++++++++++++++++--- validator/src/validation/range.rs | 47 ++++-- validator/src/validation/required.rs | 16 +- validator/src/validation/urls.rs | 39 ++++- validator_derive/src/lib.rs | 2 - validator_derive_tests/tests/email.rs | 35 ++++ validator_derive_tests/tests/length.rs | 68 ++++++++ validator_derive_tests/tests/required.rs | 30 ++++ validator_derive_tests/tests/url.rs | 50 ++++++ 11 files changed, 520 insertions(+), 76 deletions(-) diff --git a/validator/src/lib.rs b/validator/src/lib.rs index ca71b3b2..e4fc52da 100644 --- a/validator/src/lib.rs +++ b/validator/src/lib.rs @@ -73,18 +73,18 @@ mod validation; pub use validation::cards::validate_credit_card; pub use validation::contains::validate_contains; pub use validation::does_not_contain::validate_does_not_contain; -pub use validation::email::validate_email; +pub use validation::email::{validate_email, ValidateEmail}; pub use validation::ip::{validate_ip, validate_ip_v4, validate_ip_v6}; -pub use validation::length::validate_length; +pub use validation::length::{validate_length, ValidateLength}; pub use validation::must_match::validate_must_match; #[cfg(feature = "unic")] pub use validation::non_control_character::validate_non_control_character; #[cfg(feature = "phone")] pub use validation::phone::validate_phone; -pub use validation::range::validate_range; +pub use validation::range::{validate_range, ValidateRange}; -pub use validation::required::validate_required; -pub use validation::urls::validate_url; +pub use validation::required::{validate_required, ValidateRequired}; +pub use validation::urls::{validate_url, ValidateUrl}; pub use traits::{Contains, HasLen, Validate, ValidateArgs}; pub use types::{ValidationError, ValidationErrors, ValidationErrorsKind}; diff --git a/validator/src/validation/email.rs b/validator/src/validation/email.rs index e5aab3b6..38b55f8a 100644 --- a/validator/src/validation/email.rs +++ b/validator/src/validation/email.rs @@ -21,39 +21,8 @@ lazy_static! { /// [RFC 5322](https://tools.ietf.org/html/rfc5322) is not practical in most circumstances and allows email addresses /// that are unfamiliar to most users. #[must_use] -pub fn validate_email<'a, T>(val: T) -> bool -where - T: Into>, -{ - let val = val.into(); - if val.is_empty() || !val.contains('@') { - return false; - } - let parts: Vec<&str> = val.rsplitn(2, '@').collect(); - let user_part = parts[1]; - let domain_part = parts[0]; - - // validate the length of each part of the email, BEFORE doing the regex - // according to RFC5321 the max length of the local part is 64 characters - // and the max length of the domain part is 255 characters - // https://datatracker.ietf.org/doc/html/rfc5321#section-4.5.3.1.1 - if user_part.length() > 64 || domain_part.length() > 255 { - return false; - } - - if !EMAIL_USER_RE.is_match(user_part) { - return false; - } - - if !validate_domain_part(domain_part) { - // Still the possibility of an [IDN](https://en.wikipedia.org/wiki/Internationalized_domain_name) - return match domain_to_ascii(domain_part) { - Ok(d) => validate_domain_part(&d), - Err(_) => false, - }; - } - - true +pub fn validate_email(val: T) -> bool { + val.validate_email() } /// Checks if the domain is a valid domain and if not, check whether it's an IP @@ -73,6 +42,68 @@ fn validate_domain_part(domain_part: &str) -> bool { } } +pub trait ValidateEmail { + fn validate_email(&self) -> bool { + let val = self.to_email_string(); + + if val.is_empty() || !val.contains('@') { + return false; + } + + let parts: Vec<&str> = val.rsplitn(2, '@').collect(); + let user_part = parts[1]; + let domain_part = parts[0]; + + // validate the length of each part of the email, BEFORE doing the regex + // according to RFC5321 the max length of the local part is 64 characters + // and the max length of the domain part is 255 characters + // https://datatracker.ietf.org/doc/html/rfc5321#section-4.5.3.1.1 + if user_part.length() > 64 || domain_part.length() > 255 { + return false; + } + + if !EMAIL_USER_RE.is_match(user_part) { + return false; + } + + if !validate_domain_part(domain_part) { + // Still the possibility of an [IDN](https://en.wikipedia.org/wiki/Internationalized_domain_name) + return match domain_to_ascii(domain_part) { + Ok(d) => validate_domain_part(&d), + Err(_) => false, + }; + } + + true + } + + fn to_email_string<'a>(&'a self) -> Cow<'a, str>; +} + +impl ValidateEmail for &str { + fn to_email_string(&self) -> Cow<'_, str> { + Cow::from(*self) + } +} + +impl ValidateEmail for String { + fn to_email_string(&self) -> Cow<'_, str> { + Cow::from(self) + } +} + +impl ValidateEmail for &String { + fn to_email_string(&self) -> Cow<'_, str> { + Cow::from(*self) + } +} + +impl ValidateEmail for Cow<'_, str> { + fn to_email_string(&self) -> Cow<'_, str> { + self.clone() + } +} + #[cfg(test)] mod tests { use std::borrow::Cow; diff --git a/validator/src/validation/length.rs b/validator/src/validation/length.rs index 759c6810..3f2b0075 100644 --- a/validator/src/validation/length.rs +++ b/validator/src/validation/length.rs @@ -1,4 +1,7 @@ -use crate::traits::HasLen; +use std::{borrow::Cow, collections::{HashMap, HashSet, BTreeMap, BTreeSet}}; + +#[cfg(feature = "indexmap")] +use indexmap::{IndexMap, IndexSet}; /// Validates the length of the value given. /// If the validator has `equal` set, it will ignore any `min` and `max` value. @@ -6,37 +9,156 @@ use crate::traits::HasLen; /// If you apply it on String, don't forget that the length can be different /// from the number of visual characters for Unicode #[must_use] -pub fn validate_length( +pub fn validate_length( value: T, min: Option, max: Option, equal: Option, ) -> bool { - let val_length = value.length(); - - if let Some(eq) = equal { - return val_length == eq; - } else { - if let Some(m) = min { - if val_length < m { - return false; - } - } - if let Some(m) = max { - if val_length > m { - return false; - } - } - } + value.validate_length(min, max, equal) +} + +pub trait ValidateLength { + fn validate_length(&self, min: Option, max: Option, equal: Option) -> bool { + let length = self.length(); + + if let Some(eq) = equal { + return length == eq; + } else { + if let Some(m) = min { + if length < m { + return false; + } + } + if let Some(m) = max { + if length > m { + return false; + } + } + } + + true + } + + fn length(&self) -> u64; +} + +impl ValidateLength for String { + fn length(&self) -> u64 { + self.chars().count() as u64 + } +} + +impl<'a> ValidateLength for &'a String { + fn length(&self) -> u64 { + self.chars().count() as u64 + } +} + +impl<'a> ValidateLength for &'a str { + fn length(&self) -> u64 { + self.chars().count() as u64 + } +} + +impl<'a> ValidateLength for Cow<'a, str> { + fn length(&self) -> u64 { + self.chars().count() as u64 + } +} + +impl ValidateLength for Vec { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl<'a, T> ValidateLength for &'a Vec { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl ValidateLength for &[T] { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl ValidateLength for [T; N] { + fn length(&self) -> u64 { + N as u64 + } +} + +impl ValidateLength for &[T; N] { + fn length(&self) -> u64 { + N as u64 + } +} + +impl<'a, K, V, S> ValidateLength for &'a HashMap { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl ValidateLength for HashMap { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl<'a, T, S> ValidateLength for &'a HashSet { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl<'a, K, V> ValidateLength for &'a BTreeMap { + fn length(&self) -> u64 { + self.len() as u64 + } +} - true +impl<'a, T> ValidateLength for &'a BTreeSet { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +impl ValidateLength for BTreeSet { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +#[cfg(feature = "indexmap")] +impl<'a, K, V> ValidateLength for &'a IndexMap { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +#[cfg(feature = "indexmap")] +impl<'a, T> ValidateLength for &'a IndexSet { + fn length(&self) -> u64 { + self.len() as u64 + } +} + +#[cfg(feature = "indexmap")] +impl ValidateLength for IndexSet { + fn length(&self) -> u64 { + self.len() as u64 + } } #[cfg(test)] mod tests { use std::borrow::Cow; - use super::validate_length; + use crate::{validate_length, validation::length::ValidateLength}; #[test] fn test_validate_length_equal_overrides_min_max() { @@ -76,4 +198,44 @@ mod tests { fn test_validate_length_unicode_chars() { assert!(validate_length("日本", None, None, Some(2))); } + + + #[test] + fn test_validate_length_trait_equal_overrides_min_max() { + assert!(String::from("hello").validate_length(Some(1), Some(2), Some(5))); + } + + #[test] + fn test_validate_length_trait_string_min_max() { + assert!(String::from("hello").validate_length(Some(1), Some(10), None)); + } + + #[test] + fn test_validate_length_trait_string_min_only() { + assert!(!String::from("hello").validate_length(Some(10), None, None)); + } + + #[test] + fn test_validate_length_trait_string_max_only() { + assert!(!String::from("hello").validate_length(None, Some(1), None)); + } + + #[test] + fn test_validate_length_trait_cow() { + let test: Cow<'static, str> = "hello".into(); + assert!(test.validate_length(None, None, Some(5))); + + let test: Cow<'static, str> = String::from("hello").into(); + assert!(test.validate_length(None, None, Some(5))); + } + + #[test] + fn test_validate_length_trait_vec() { + assert!(vec![1, 2, 3].validate_length(None, None, Some(3))); + } + + #[test] + fn test_validate_length_trait_unicode_chars() { + assert!(String::from("日本").validate_length(None, None, Some(2))); + } } diff --git a/validator/src/validation/range.rs b/validator/src/validation/range.rs index 3b3cf6ae..8f339caa 100644 --- a/validator/src/validation/range.rs +++ b/validator/src/validation/range.rs @@ -2,23 +2,50 @@ /// optional and will only be validated if they are not `None` /// #[must_use] -pub fn validate_range(value: T, min: Option, max: Option) -> bool +pub fn validate_range>(value: T, min: Option, max: Option) -> bool { + value.validate_range(min, max) +} + +pub trait ValidateRange { + fn validate_range(&self, min: Option, max: Option) -> bool { + if let Some(max) = max { + if self.greater_than(max) { + return false; + } + } + + if let Some(min) = min { + if self.less_than(min) { + return false; + } + } + + true + } + + fn greater_than(&self, max: T) -> bool; + fn less_than(&self, min: T) -> bool; +} + +impl ValidateRange for T where - T: PartialOrd + PartialEq, + T: PartialEq + PartialOrd, { - if let Some(max) = max { - if value > max { - return false; + fn greater_than(&self, max: T) -> bool { + if self > &max { + return true; } + + false } - if let Some(min) = min { - if value < min { - return false; + fn less_than(&self, min: T) -> bool { + if self < &min { + return true; } - } - true + false + } } #[cfg(test)] diff --git a/validator/src/validation/required.rs b/validator/src/validation/required.rs index 80b06b04..863aeee3 100644 --- a/validator/src/validation/required.rs +++ b/validator/src/validation/required.rs @@ -1,5 +1,19 @@ /// Validates whether the given Option is Some #[must_use] -pub fn validate_required(val: &Option) -> bool { +pub fn validate_required(val: &T) -> bool { val.is_some() } + +pub trait ValidateRequired { + fn validate_required(&self) -> bool { + self.is_some() + } + + fn is_some(&self) -> bool; +} + +impl ValidateRequired for Option { + fn is_some(&self) -> bool { + self.is_some() + } +} diff --git a/validator/src/validation/urls.rs b/validator/src/validation/urls.rs index 41774d1a..1528dcd4 100644 --- a/validator/src/validation/urls.rs +++ b/validator/src/validation/urls.rs @@ -3,11 +3,40 @@ use url::Url; /// Validates whether the string given is a url #[must_use] -pub fn validate_url<'a, T>(val: T) -> bool -where - T: Into>, -{ - Url::parse(val.into().as_ref()).is_ok() +pub fn validate_url(val: T) -> bool { + val.validate_url() +} + +pub trait ValidateUrl { + fn validate_url(&self) -> bool { + Url::parse(&self.to_url_string()).is_ok() + } + + fn to_url_string<'a>(&'a self) -> Cow<'a, str>; +} + +impl ValidateUrl for &str { + fn to_url_string(&self) -> Cow<'_, str> { + Cow::from(*self) + } +} + +impl ValidateUrl for String { + fn to_url_string(&self) -> Cow<'_, str> { + Cow::from(self) + } +} + +impl ValidateUrl for &String { + fn to_url_string(&self) -> Cow<'_, str> { + Cow::from(*self) + } +} + +impl ValidateUrl for Cow<'_, str> { + fn to_url_string(&self) -> Cow<'_, str> { + self.clone() + } } #[cfg(test)] diff --git a/validator_derive/src/lib.rs b/validator_derive/src/lib.rs index fdecfbd1..143fbd00 100644 --- a/validator_derive/src/lib.rs +++ b/validator_derive/src/lib.rs @@ -408,11 +408,9 @@ fn find_validators_for_field( syn::Meta::Path(ref name) => { match name.get_ident().unwrap().to_string().as_ref() { "email" => { - assert_string_type("email", field_type, &field.ty); validators.push(FieldValidation::new(Validator::Email)); } "url" => { - assert_string_type("url", field_type, &field.ty); validators.push(FieldValidation::new(Validator::Url)); } #[cfg(feature = "phone")] diff --git a/validator_derive_tests/tests/email.rs b/validator_derive_tests/tests/email.rs index 694b357f..0eaf4ca6 100644 --- a/validator_derive_tests/tests/email.rs +++ b/validator_derive_tests/tests/email.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use validator::Validate; #[test] @@ -65,3 +66,37 @@ fn can_specify_message_for_email() { assert_eq!(errs["val"].len(), 1); assert_eq!(errs["val"][0].clone().message.unwrap(), "oops"); } + +#[test] +fn can_validate_custom_impl_for_email() { + use std::borrow::Cow; + + #[derive(Debug, Serialize)] + struct CustomEmail { + user_part: String, + domain_part: String, + } + + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(email)] + val: CustomEmail, + } + + impl validator::ValidateEmail for &CustomEmail { + fn to_email_string(&self) -> Cow<'_, str> { + Cow::from(format!("{}@{}", self.user_part, self.domain_part)) + } + } + + let valid = TestStruct { + val: CustomEmail { user_part: "username".to_string(), domain_part: "gmail.com".to_owned() }, + }; + + let invalid = TestStruct { + val: CustomEmail { user_part: "abc".to_string(), domain_part: "".to_owned() }, + }; + + assert!(valid.validate().is_ok()); + assert!(invalid.validate().is_err()); +} diff --git a/validator_derive_tests/tests/length.rs b/validator_derive_tests/tests/length.rs index 317b2999..aa5de1fe 100644 --- a/validator_derive_tests/tests/length.rs +++ b/validator_derive_tests/tests/length.rs @@ -222,3 +222,71 @@ fn can_validate_set_ref_for_length() { assert_eq!(errs["val"][0].params["min"], 5); assert_eq!(errs["val"][0].params["max"], 10); } + +#[test] +fn can_validate_custom_impl_for_length() { + use serde::Serialize; + + #[derive(Debug, Serialize)] + struct CustomString(String); + + impl validator::ValidateLength for &CustomString { + fn validate_length(&self, min: Option, max: Option, equal: Option) -> bool { + let length = self.length(); + + if let Some(eq) = equal { + return length == eq; + } else { + if let Some(m) = min { + if length < m { + return false; + } + } + if let Some(m) = max { + if length > m { + return false; + } + } + } + + true + } + + fn length(&self) -> u64 { + self.0.chars().count() as u64 + } + } + + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(length(min = 5, max = 10))] + val: CustomString, + } + + #[derive(Debug, Validate)] + struct EqualsTestStruct { + #[validate(length(equal = 11))] + val: CustomString + } + + let too_short = TestStruct { + val: CustomString(String::from("oops")) + }; + + let too_long = TestStruct { + val: CustomString(String::from("too long for this")) + }; + + let ok = TestStruct { + val: CustomString(String::from("perfect")) + }; + + let equals_ok = EqualsTestStruct { + val: CustomString(String::from("just enough")) + }; + + assert!(too_short.validate().is_err()); + assert!(too_long.validate().is_err()); + assert!(ok.validate().is_ok()); + assert!(equals_ok.validate().is_ok()); +} \ No newline at end of file diff --git a/validator_derive_tests/tests/required.rs b/validator_derive_tests/tests/required.rs index f03d6653..c9f751b1 100644 --- a/validator_derive_tests/tests/required.rs +++ b/validator_derive_tests/tests/required.rs @@ -90,3 +90,33 @@ fn can_specify_message_for_required() { assert_eq!(errs["val"].len(), 1); assert_eq!(errs["val"][0].clone().message.unwrap(), "oops"); } + +#[test] +fn can_validate_custom_impl_for_required() { + #[derive(Debug, Serialize)] + enum CustomOption { + Something(T), + Nothing, + } + + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(required)] + val: CustomOption, + } + + impl validator::ValidateRequired for CustomOption { + fn is_some(&self) -> bool { + match self { + CustomOption::Something(_) => true, + CustomOption::Nothing => false, + } + } + } + + let something = TestStruct { val: CustomOption::Something("this is something".to_string()) }; + let nothing = TestStruct { val: CustomOption::Nothing }; + + assert!(something.validate().is_ok()); + assert!(nothing.validate().is_err()); +} \ No newline at end of file diff --git a/validator_derive_tests/tests/url.rs b/validator_derive_tests/tests/url.rs index 88d2f48b..fe8d4888 100644 --- a/validator_derive_tests/tests/url.rs +++ b/validator_derive_tests/tests/url.rs @@ -67,3 +67,53 @@ fn can_specify_message_for_url() { assert_eq!(errs["val"].len(), 1); assert_eq!(errs["val"][0].clone().message.unwrap(), "oops"); } + +#[test] +fn can_validate_custom_impl_for_url() { + use serde::Serialize; + use std::borrow::Cow; + + #[derive(Debug, Serialize)] + struct CustomUrl { + scheme: String, + subdomain: String, + domain: String, + top_level_domain: String, + } + + #[derive(Debug, Validate)] + struct TestStruct { + #[validate(url)] + val: CustomUrl, + } + + impl validator::ValidateUrl for &CustomUrl { + fn to_url_string(&self) -> Cow<'_, str> { + Cow::from(format!( + "{}://{}.{}.{}", + self.scheme, self.subdomain, self.domain, self.top_level_domain + )) + } + } + + let valid = TestStruct { + val: CustomUrl { + scheme: "http".to_string(), + subdomain: "www".to_string(), + domain: "google".to_string(), + top_level_domain: "com".to_string(), + }, + }; + + let invalid = TestStruct { + val: CustomUrl { + scheme: "".to_string(), + subdomain: "".to_string(), + domain: "google".to_string(), + top_level_domain: "".to_string(), + }, + }; + + assert!(valid.validate().is_ok()); + assert!(invalid.validate().is_err()); +}