Skip to content

Commit

Permalink
Improve codegen for enums with many cases (#9122)
Browse files Browse the repository at this point in the history
* Improve codegen for enums with many cases

This commit improves the compile time of generating bindings for enums
with many cases in them (e.g. 1000+). This is done by optimizing for
enums specifically rather than handling them generically like other
variants which can reduce the amount of code going into rustc to O(1)
instead of O(N) with the number of cases. This in turn can greatly
reduce compile time.

The tradeoff made in this commit is that enums are now required to have
`#[repr(...)]` annotations along with no Rust-level discriminants
specified. This enables the use of a `transmute` to lift a discriminant
into Rust with a simple bounds check. Previously this was one large
`match` statement.

Closes #9081

* Fix some tests

* Add repr tag in fuzzing

* Fix syntax for Rust 1.78
  • Loading branch information
alexcrichton authored Aug 12, 2024
1 parent c2cdfee commit fa41d13
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 78 deletions.
369 changes: 291 additions & 78 deletions crates/component-macro/src/component.rs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions crates/component-macro/tests/expanded/simple-wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Errno {
#[component(name = "e")]
E,
Expand Down
1 change: 1 addition & 0 deletions crates/component-macro/tests/expanded/simple-wasi_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Errno {
#[component(name = "e")]
E,
Expand Down
2 changes: 2 additions & 0 deletions crates/component-macro/tests/expanded/small-anonymous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down Expand Up @@ -207,6 +208,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down Expand Up @@ -220,6 +221,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum Error {
#[component(name = "success")]
Success,
Expand Down
4 changes: 4 additions & 0 deletions crates/component-macro/tests/expanded/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -313,6 +314,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down Expand Up @@ -846,6 +848,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -1109,6 +1112,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down
4 changes: 4 additions & 0 deletions crates/component-macro/tests/expanded/variants_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -320,6 +321,7 @@ pub mod foo {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down Expand Up @@ -862,6 +864,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum E1 {
#[component(name = "a")]
A,
Expand Down Expand Up @@ -1125,6 +1128,7 @@ pub mod exports {
#[derive(wasmtime::component::Lower)]
#[component(enum)]
#[derive(Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum MyErrno {
#[component(name = "bad1")]
Bad1,
Expand Down
18 changes: 18 additions & 0 deletions crates/environ/src/component/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,24 @@ impl CanonicalAbiInfo {
}
}

/// Calculates ABI information for an enum with `cases` cases.
pub const fn enum_(cases: usize) -> CanonicalAbiInfo {
// NB: this is basically a duplicate definition of
// `CanonicalAbiInfo::variant`, these should be kept in sync.

let discrim_size = match DiscriminantSize::from_count(cases) {
Some(size) => size.byte_size(),
None => unreachable!(),
};
CanonicalAbiInfo {
size32: discrim_size,
align32: discrim_size,
size64: discrim_size,
align64: discrim_size,
flat_count: Some(1),
}
}

/// Returns the flat count of this ABI information so long as the count
/// doesn't exceed the `max` specified.
pub fn flat_count(&self, max: usize) -> Option<usize> {
Expand Down
6 changes: 6 additions & 0 deletions crates/misc/component-fuzz-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,10 +569,16 @@ pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStre
.collect::<TokenStream>();

let name = make_rust_name(name_counter);
let repr = match count.ilog2() {
0..=7 => quote!(u8),
8..=15 => quote!(u16),
_ => quote!(u32),
};

declarations.extend(quote! {
#[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)]
#[component(enum)]
#[repr(#repr)]
enum #name {
#cases
}
Expand Down
2 changes: 2 additions & 0 deletions crates/wasmtime/src/runtime/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub mod __internal {
pub use anyhow;
#[cfg(feature = "async")]
pub use async_trait::async_trait;
pub use core::mem::transmute;
pub use wasmtime_environ;
pub use wasmtime_environ::component::{CanonicalAbiInfo, ComponentTypes, InterfaceType};
}
Expand Down Expand Up @@ -512,6 +513,7 @@ pub use wasmtime_component_macro::bindgen;
///
/// #[derive(ComponentType)]
/// #[component(enum)]
/// #[repr(u8)]
/// enum Setting {
/// #[component(name = "yes")]
/// Yes,
Expand Down
7 changes: 7 additions & 0 deletions crates/wit-bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,13 @@ impl<'a> InterfaceGenerator<'a> {
self.push_str(&derives.into_iter().collect::<Vec<_>>().join(", "));
self.push_str(")]\n");

let repr = match enum_.cases.len().ilog2() {
0..=7 => "u8",
8..=15 => "u16",
_ => "u32",
};
uwriteln!(self.src, "#[repr({repr})]");

self.push_str(&format!("pub enum {name} {{\n"));
for case in enum_.cases.iter() {
self.rustdoc(&case.docs);
Expand Down
3 changes: 3 additions & 0 deletions tests/all/component_model/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ fn variant_derive() -> Result<()> {
fn enum_derive() -> Result<()> {
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
#[component(enum)]
#[repr(u8)]
enum Foo {
#[component(name = "foo-bar-baz")]
A,
Expand Down Expand Up @@ -299,6 +300,7 @@ fn enum_derive() -> Result<()> {
#[add_variants(257)]
#[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
#[component(enum)]
#[repr(u16)]
enum Many {}

let component = Component::new(
Expand Down Expand Up @@ -330,6 +332,7 @@ fn enum_derive() -> Result<()> {
// #[add_variants(65537)]
// #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)]
// #[component(enum)]
// #[repr(u32)]
// enum ManyMore {}

Ok(())
Expand Down

0 comments on commit fa41d13

Please sign in to comment.