Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve codegen for enums with many cases #9122

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
369 changes: 291 additions & 78 deletions crates/component-macro/src/component.rs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions crates/component-macro/tests/expanded/simple-wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Errno {
#[component(name = "e")]
E,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Errno {
#[component(name = "e")]
E,
Expand Down
2 changes: 2 additions & 0 deletions crates/component-macro/tests/expanded/small-anonymous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down Expand Up @@ -207,6 +208,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down Expand Up @@ -220,6 +221,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down
4 changes: 4 additions & 0 deletions crates/component-macro/tests/expanded/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -313,6 +314,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down Expand Up @@ -846,6 +848,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -1109,6 +1112,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down
4 changes: 4 additions & 0 deletions crates/component-macro/tests/expanded/variants_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -320,6 +321,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down Expand Up @@ -862,6 +864,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -1125,6 +1128,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down
18 changes: 18 additions & 0 deletions crates/environ/src/component/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,24 @@ impl CanonicalAbiInfo {
}
}

/// Calculates ABI information for an enum with `cases` cases.
pub const fn enum_(cases: usize) -> CanonicalAbiInfo {
// NB: this is basically a duplicate definition of
// `CanonicalAbiInfo::variant`, these should be kept in sync.

let discrim_size = match DiscriminantSize::from_count(cases) {
Some(size) => size.byte_size(),
None => unreachable!(),
};
CanonicalAbiInfo {
size32: discrim_size,
align32: discrim_size,
size64: discrim_size,
align64: discrim_size,
flat_count: Some(1),
}
}

/// Returns the flat count of this ABI information so long as the count
/// doesn't exceed the `max` specified.
pub fn flat_count(&self, max: usize) -> Option<usize> {
Expand Down
6 changes: 6 additions & 0 deletions crates/misc/component-fuzz-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,16 @@ pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStre
.collect::<TokenStream>();

let name = make_rust_name(name_counter);
let repr = match count.ilog2() {
0..=7 => quote!(u8),
8..=15 => quote!(u16),
_ => quote!(u32),
};

declarations.extend(quote! {
#[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)]
#[component(enum)]
#[repr(#repr)]
enum #name {
#cases
}
Expand Down
2 changes: 2 additions & 0 deletions crates/wasmtime/src/runtime/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub mod __internal {
pub use anyhow;
#[cfg(feature = "async")]
pub use async_trait::async_trait;
pub use core::mem::transmute;
pub use wasmtime_environ;
pub use wasmtime_environ::component::{CanonicalAbiInfo, ComponentTypes, InterfaceType};
}
Expand Down Expand Up @@ -512,6 +513,7 @@ pub use wasmtime_component_macro::bindgen;
///
/// #[derive(ComponentType)]
/// #[component(enum)]
/// #[repr(u8)]
/// enum Setting {
/// #[component(name = "yes")]
/// Yes,
Expand Down
7 changes: 7 additions & 0 deletions crates/wit-bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,13 @@ impl<'a> InterfaceGenerator<'a> {
self.push_str(&derives.into_iter().collect::<Vec<_>>().join(", "));
self.push_str(")]\n");

let repr = match enum_.cases.len().ilog2() {
0..=7 => "u8",
8..=15 => "u16",
_ => "u32",
};
uwriteln!(self.src, "#[repr({repr})]");

self.push_str(&format!("pub enum {name} {{\n"));
for case in enum_.cases.iter() {
self.rustdoc(&case.docs);
Expand Down
3 changes: 3 additions & 0 deletions tests/all/component_model/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ fn variant_derive() -> Result<()> {
fn enum_derive() -> Result<()> {
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
#[component(enum)]
#[repr(u8)]
enum Foo {
#[component(name = "foo-bar-baz")]
A,
Expand Down Expand Up @@ -299,6 +300,7 @@ fn enum_derive() -> Result<()> {
#[add_variants(257)]
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
#[component(enum)]
#[repr(u16)]
enum Many {}

let component = Component::new(
Expand Down Expand Up @@ -330,6 +332,7 @@ fn enum_derive() -> Result<()> {
// #[add_variants(65537)]
// #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
// #[component(enum)]
// #[repr(u32)]
// enum ManyMore {}

Ok(())
Expand Down