Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter Glotfelty committed Jan 20, 2024
2 parents 565b33e + b965277 commit df478ed
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Strum has implemented the following macros:
| [IntoStaticStr] | Implements `From<MyEnum> for &'static str` on an enum |
| [EnumVariantNames] | Adds an associated `VARIANTS` constant which is an array of discriminant names |
| [EnumIter] | Creates a new type that iterates of the variants of an enum. |
| [EnumMap] | Creates a new type that stores an item of a specified type for each variant of the enum. |
| [EnumProperty] | Add custom properties to enum variants. |
| [EnumMessage] | Add a verbose message to an enum variant. |
| [EnumDiscriminants] | Generate a new type with only the discriminant names. |
Expand Down
1 change: 1 addition & 0 deletions strum_macros/src/helpers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub use self::case_style::snakify;
pub use self::inner_variant_props::HasInnerVariantProperties;
pub use self::type_props::HasTypeProperties;
pub use self::variant_props::HasStrumVariantProperties;
Expand Down
39 changes: 39 additions & 0 deletions strum_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,45 @@ pub fn enum_try_as(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
toks.into()
}

/// Creates a new type that maps all the variants of an enum to another generic value.
///
/// This macro does not support any additional data on your variants.
/// The macro creates a new type called `YourEnumTable<T>`.
/// The table has a field of type `T` for each variant of `YourEnum`. The table automatically implements `Index<T>` and `IndexMut<T>`.
/// ```
/// use strum_macros::EnumTable;
///
/// #[derive(EnumTable)]
/// enum Color {
/// Red,
/// Yellow,
/// Green,
/// Blue,
/// }
///
/// assert_eq!(ColorTable::default(), ColorTable::new(0, 0, 0, 0));
/// assert_eq!(ColorTable::filled(2), ColorTable::new(2, 2, 2, 2));
/// assert_eq!(ColorTable::from_closure(|_| 3), ColorTable::new(3, 3, 3, 3));
/// assert_eq!(ColorTable::default().transform(|_, val| val + 2), ColorTable::new(2, 2, 2, 2));
///
/// let mut complex_map = ColorTable::from_closure(|color| match color {
/// Color::Red => 0,
/// _ => 3
/// });
/// complex_map[Color::Green] = complex_map[Color::Red];
/// assert_eq!(complex_map, ColorTable::new(0, 3, 0, 3));
///
/// ```
#[proc_macro_derive(EnumTable, attributes(strum))]
pub fn enum_table(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(input as DeriveInput);

let toks =
macros::enum_table::enum_table_inner(&ast).unwrap_or_else(|err| err.to_compile_error());
debug_print_generated(&ast, &toks);
toks.into()
}

/// Add a function to enum that allows accessing variants by its discriminant
///
/// This macro adds a standalone function to obtain an enum variant by its discriminant. The macro adds
Expand Down
204 changes: 204 additions & 0 deletions strum_macros/src/macros/enum_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{spanned::Spanned, Data, DeriveInput, Fields};

use crate::helpers::{non_enum_error, snakify, HasStrumVariantProperties};

pub fn enum_table_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
let gen = &ast.generics;
let vis = &ast.vis;
let mut doc_comment = format!("A map over the variants of `{}`", name);

if gen.lifetimes().count() > 0 {
return Err(syn::Error::new(
Span::call_site(),
"`EnumTable` doesn't support enums with lifetimes.",
));
}

let variants = match &ast.data {
Data::Enum(v) => &v.variants,
_ => return Err(non_enum_error()),
};

let table_name = format_ident!("{}Table", name);

// the identifiers of each variant, in PascalCase
let mut pascal_idents = Vec::new();
// the identifiers of each struct field, in snake_case
let mut snake_idents = Vec::new();
// match arms in the form `MyEnumTable::Variant => &self.variant,`
let mut get_matches = Vec::new();
// match arms in the form `MyEnumTable::Variant => &mut self.variant,`
let mut get_matches_mut = Vec::new();
// match arms in the form `MyEnumTable::Variant => self.variant = new_value`
let mut set_matches = Vec::new();
// struct fields of the form `variant: func(MyEnum::Variant),*
let mut closure_fields = Vec::new();
// struct fields of the form `variant: func(MyEnum::Variant, self.variant),`
let mut transform_fields = Vec::new();

// identifiers for disabled variants
let mut disabled_variants = Vec::new();
// match arms for disabled variants
let mut disabled_matches = Vec::new();

for variant in variants {
// skip disabled variants
if variant.get_variant_properties()?.disabled.is_some() {
let disabled_ident = &variant.ident;
let panic_message = format!(
"Can't use `{}` with `{}` - variant is disabled for Strum features",
disabled_ident, table_name
);
disabled_variants.push(disabled_ident);
disabled_matches.push(quote!(#name::#disabled_ident => panic!(#panic_message),));
continue;
}

// Error on variants with data
if variant.fields != Fields::Unit {
return Err(syn::Error::new(
variant.fields.span(),
"`EnumTable` doesn't support enums with non-unit variants",
));
};

let pascal_case = &variant.ident;
let snake_case = format_ident!("_{}", snakify(&pascal_case.to_string()));

get_matches.push(quote! {#name::#pascal_case => &self.#snake_case,});
get_matches_mut.push(quote! {#name::#pascal_case => &mut self.#snake_case,});
set_matches.push(quote! {#name::#pascal_case => self.#snake_case = new_value,});
closure_fields.push(quote! {#snake_case: func(#name::#pascal_case),});
transform_fields.push(quote! {#snake_case: func(#name::#pascal_case, &self.#snake_case),});
pascal_idents.push(pascal_case);
snake_idents.push(snake_case);
}

// Error on empty enums
if pascal_idents.is_empty() {
return Err(syn::Error::new(
variants.span(),
"`EnumTable` requires at least one non-disabled variant",
));
}

// if the index operation can panic, add that to the documentation
if !disabled_variants.is_empty() {
doc_comment.push_str(&format!(
"\n# Panics\nIndexing `{}` with any of the following variants will cause a panic:",
table_name
));
for variant in disabled_variants {
doc_comment.push_str(&format!("\n\n- `{}::{}`", name, variant));
}
}

let doc_new = format!(
"Create a new {} with a value for each variant of {}",
table_name, name
);
let doc_closure = format!(
"Create a new {} by running a function on each variant of `{}`",
table_name, name
);
let doc_transform = format!("Create a new `{}` by running a function on each variant of `{}` and the corresponding value in the current `{0}`", table_name, name);
let doc_filled = format!(
"Create a new `{}` with the same value in each field.",
table_name
);
let doc_option_all = format!("Converts `{}<Option<T>>` into `Option<{0}<T>>`. Returns `Some` if all fields are `Some`, otherwise returns `None`.", table_name);
let doc_result_all_ok = format!("Converts `{}<Result<T, E>>` into `Result<{0}<T>, E>`. Returns `Ok` if all fields are `Ok`, otherwise returns `Err`.", table_name);

Ok(quote! {
#[doc = #doc_comment]
#[allow(
missing_copy_implementations,
)]
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
#vis struct #table_name<T> {
#(#snake_idents: T,)*
}

impl<T: Clone> #table_name<T> {
#[doc = #doc_filled]
#vis fn filled(value: T) -> #table_name<T> {
#table_name {
#(#snake_idents: value.clone(),)*
}
}
}

impl<T> #table_name<T> {
#[doc = #doc_new]
#vis fn new(
#(#snake_idents: T,)*
) -> #table_name<T> {
#table_name {
#(#snake_idents,)*
}
}

#[doc = #doc_closure]
#vis fn from_closure<F: Fn(#name)->T>(func: F) -> #table_name<T> {
#table_name {
#(#closure_fields)*
}
}

#[doc = #doc_transform]
#vis fn transform<U, F: Fn(#name, &T)->U>(&self, func: F) -> #table_name<U> {
#table_name {
#(#transform_fields)*
}
}

}

impl<T> ::core::ops::Index<#name> for #table_name<T> {
type Output = T;

fn index(&self, idx: #name) -> &T {
match idx {
#(#get_matches)*
#(#disabled_matches)*
}
}
}

impl<T> ::core::ops::IndexMut<#name> for #table_name<T> {
fn index_mut(&mut self, idx: #name) -> &mut T {
match idx {
#(#get_matches_mut)*
#(#disabled_matches)*
}
}
}

impl<T> #table_name<::core::option::Option<T>> {
#[doc = #doc_option_all]
#vis fn all(self) -> ::core::option::Option<#table_name<T>> {
if let #table_name {
#(#snake_idents: ::core::option::Option::Some(#snake_idents),)*
} = self {
::core::option::Option::Some(#table_name {
#(#snake_idents,)*
})
} else {
::core::option::Option::None
}
}
}

impl<T, E> #table_name<::core::result::Result<T, E>> {
#[doc = #doc_result_all_ok]
#vis fn all_ok(self) -> ::core::result::Result<#table_name<T>, E> {
::core::result::Result::Ok(#table_name {
#(#snake_idents: self.#snake_idents?,)*
})
}
}
})
}
1 change: 1 addition & 0 deletions strum_macros/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod enum_count;
pub mod enum_discriminants;
pub mod enum_is;
pub mod enum_iter;
pub mod enum_table;
pub mod enum_messages;
pub mod enum_properties;
pub mod enum_try_as;
Expand Down
101 changes: 101 additions & 0 deletions strum_tests/tests/enum_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use strum::EnumTable;

#[derive(EnumTable)]
enum Color {
Red,
Yellow,
Green,
#[strum(disabled)]
Teal,
Blue,
#[strum(disabled)]
Indigo,
}

// even though this isn't used, it needs to be a test
// because if it doesn't compile, enum variants that conflict with keywords won't work
#[derive(EnumTable)]
enum Keyword {
Const,
}

#[test]
fn default() {
assert_eq!(ColorTable::default(), ColorTable::new(0, 0, 0, 0));
}

#[test]
#[should_panic]
fn disabled() {
let _ = ColorTable::<u8>::default()[Color::Indigo];
}

#[test]
fn filled() {
assert_eq!(ColorTable::filled(42), ColorTable::new(42, 42, 42, 42));
}

#[test]
fn from_closure() {
assert_eq!(
ColorTable::from_closure(|color| match color {
Color::Red => 1,
_ => 2,
}),
ColorTable::new(1, 2, 2, 2)
);
}

#[test]
fn clone() {
let cm = ColorTable::filled(String::from("Some Text Data"));
assert_eq!(cm.clone(), cm);
}

#[test]
fn index() {
let map = ColorTable::new(18, 25, 7, 2);
assert_eq!(map[Color::Red], 18);
assert_eq!(map[Color::Yellow], 25);
assert_eq!(map[Color::Green], 7);
assert_eq!(map[Color::Blue], 2);
}

#[test]
fn index_mut() {
let mut map = ColorTable::new(18, 25, 7, 2);
map[Color::Green] = 5;
map[Color::Red] *= 4;
assert_eq!(map[Color::Green], 5);
assert_eq!(map[Color::Red], 72);
}

#[test]
fn option_all() {
let mut map: ColorTable<Option<u8>> = ColorTable::filled(None);
map[Color::Red] = Some(64);
map[Color::Green] = Some(32);
map[Color::Blue] = Some(16);

assert_eq!(map.clone().all(), None);

map[Color::Yellow] = Some(8);
assert_eq!(map.all(), Some(ColorTable::new(64, 8, 32, 16)));
}

#[test]
fn result_all_ok() {
let mut map: ColorTable<Result<u8, u8>> = ColorTable::filled(Ok(4));
assert_eq!(map.clone().all_ok(), Ok(ColorTable::filled(4)));
map[Color::Red] = Err(22);
map[Color::Yellow] = Err(100);
assert_eq!(map.clone().all_ok(), Err(22));
map[Color::Red] = Ok(1);
assert_eq!(map.all_ok(), Err(100));
}

#[test]
fn transform() {
let all_two = ColorTable::filled(2);
assert_eq!(all_two.transform(|_, n| *n * 2), ColorTable::filled(4));
}

0 comments on commit df478ed

Please sign in to comment.