diff --git a/crates/wiggle/generate/src/codegen_settings.rs b/crates/wiggle/generate/src/codegen_settings.rs index 616a12dd7ea7..11391ed71f1d 100644 --- a/crates/wiggle/generate/src/codegen_settings.rs +++ b/crates/wiggle/generate/src/codegen_settings.rs @@ -1,6 +1,6 @@ -use crate::config::{AsyncConf, ErrorConf, TracingConf}; +use crate::config::{AsyncConf, ErrorConf, ErrorConfField, TracingConf}; use anyhow::{anyhow, Error}; -use proc_macro2::TokenStream; +use proc_macro2::{Ident, TokenStream}; use quote::quote; use std::collections::HashMap; use std::rc::Rc; @@ -39,7 +39,7 @@ impl CodegenSettings { } pub struct ErrorTransform { - m: Vec, + m: Vec, } impl ErrorTransform { @@ -49,7 +49,13 @@ impl ErrorTransform { pub fn new(conf: &ErrorConf, doc: &Document) -> Result { let mut richtype_identifiers = HashMap::new(); let m = conf.iter().map(|(ident, field)| - if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) { + match field { + ErrorConfField::Trappable(field) => if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) { + Ok(ErrorType::Generated(TrappableErrorType { abi_type, rich_type: field.rich_error.clone() })) + } else { + Err(anyhow!("No witx typename \"{}\" found", ident.to_string())) + }, + ErrorConfField::User(field) => if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) { if let Some(ident) = field.rich_error.get_ident() { if let Some(prior_def) = richtype_identifiers.insert(ident.clone(), field.err_loc.clone()) { @@ -58,11 +64,11 @@ impl ErrorTransform { ident, prior_def )); } - Ok(UserErrorType { + Ok(ErrorType::User(UserErrorType { abi_type, rich_type: field.rich_error.clone(), method_fragment: ident.to_string() - }) + })) } else { return Err(anyhow!( "rich error type must be identifier for now - TODO add ability to provide a corresponding identifier: {:?}", @@ -71,23 +77,52 @@ impl ErrorTransform { } } else { Err(anyhow!("No witx typename \"{}\" found", ident.to_string())) } + } ).collect::, Error>>()?; Ok(Self { m }) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.m.iter() } - pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> { + pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&ErrorType> { match tref { TypeRef::Name(nt) => self.for_name(nt), TypeRef::Value { .. } => None, } } - pub fn for_name(&self, nt: &NamedType) -> Option<&UserErrorType> { - self.m.iter().find(|u| u.abi_type.name == nt.name) + pub fn for_name(&self, nt: &NamedType) -> Option<&ErrorType> { + self.m.iter().find(|e| e.abi_type().name == nt.name) + } +} + +pub enum ErrorType { + User(UserErrorType), + Generated(TrappableErrorType), +} +impl ErrorType { + pub fn abi_type(&self) -> &NamedType { + match self { + Self::User(u) => &u.abi_type, + Self::Generated(r) => &r.abi_type, + } + } +} + +pub struct TrappableErrorType { + abi_type: Rc, + rich_type: Ident, +} + +impl TrappableErrorType { + pub fn abi_type(&self) -> TypeRef { + TypeRef::Name(self.abi_type.clone()) + } + pub fn typename(&self) -> TokenStream { + let richtype = &self.rich_type; + quote!(#richtype) } } diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 1eedc1840ec4..2e42a93ef5e7 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -27,6 +27,7 @@ mod kw { syn::custom_keyword!(wasmtime); syn::custom_keyword!(tracing); syn::custom_keyword!(disable_for); + syn::custom_keyword!(trappable); } #[derive(Debug, Clone)] @@ -274,14 +275,14 @@ impl Parse for ErrorConf { content.parse_terminated(Parse::parse)?; let mut m = HashMap::new(); for i in items { - match m.insert(i.abi_error.clone(), i.clone()) { + match m.insert(i.abi_error().clone(), i.clone()) { None => {} Some(prev_def) => { return Err(Error::new( - i.err_loc, + *i.err_loc(), format!( "duplicate definition of rich error type for {:?}: previously defined at {:?}", - i.abi_error, prev_def.err_loc, + i.abi_error(), prev_def.err_loc(), ), )) } @@ -291,14 +292,67 @@ impl Parse for ErrorConf { } } +#[derive(Debug, Clone)] +pub enum ErrorConfField { + Trappable(TrappableErrorConfField), + User(UserErrorConfField), +} +impl ErrorConfField { + pub fn abi_error(&self) -> &Ident { + match self { + Self::Trappable(t) => &t.abi_error, + Self::User(u) => &u.abi_error, + } + } + pub fn err_loc(&self) -> &Span { + match self { + Self::Trappable(t) => &t.err_loc, + Self::User(u) => &u.err_loc, + } + } +} + +impl Parse for ErrorConfField { + fn parse(input: ParseStream) -> Result { + let err_loc = input.span(); + let abi_error = input.parse::()?; + let _arrow: Token![=>] = input.parse()?; + + let lookahead = input.lookahead1(); + if lookahead.peek(kw::trappable) { + let _ = input.parse::()?; + let rich_error = input.parse()?; + Ok(ErrorConfField::Trappable(TrappableErrorConfField { + abi_error, + rich_error, + err_loc, + })) + } else { + let rich_error = input.parse::()?; + Ok(ErrorConfField::User(UserErrorConfField { + abi_error, + rich_error, + err_loc, + })) + } + } +} + +#[derive(Clone, Debug)] +pub struct TrappableErrorConfField { + pub abi_error: Ident, + pub rich_error: Ident, + pub err_loc: Span, +} + #[derive(Clone)] -pub struct ErrorConfField { +pub struct UserErrorConfField { pub abi_error: Ident, pub rich_error: syn::Path, pub err_loc: Span, } -impl std::fmt::Debug for ErrorConfField { +impl std::fmt::Debug for UserErrorConfField { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ErrorConfField") .field("abi_error", &self.abi_error) @@ -308,20 +362,6 @@ impl std::fmt::Debug for ErrorConfField { } } -impl Parse for ErrorConfField { - fn parse(input: ParseStream) -> Result { - let err_loc = input.span(); - let abi_error = input.parse::()?; - let _arrow: Token![=>] = input.parse()?; - let rich_error = input.parse::()?; - Ok(ErrorConfField { - abi_error, - rich_error, - err_loc, - }) - } -} - #[derive(Clone, Default, Debug)] /// Modules and funcs that have async signatures pub struct AsyncConf { diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index bd1f1eebdf40..63177c3c73ef 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -1,7 +1,7 @@ -use crate::codegen_settings::CodegenSettings; +use crate::codegen_settings::{CodegenSettings, ErrorType}; use crate::lifetimes::anon_lifetime; use crate::module_trait::passed_by_reference; -use crate::names::Names; +use crate::names; use crate::types::WiggleType; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -9,54 +9,50 @@ use std::mem; use witx::Instruction; pub fn define_func( - names: &Names, module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> TokenStream { - let (ts, _bounds) = _define_func(names, module, func, settings); + let (ts, _bounds) = _define_func(module, func, settings); ts } pub fn func_bounds( - names: &Names, module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> Vec { - let (_ts, bounds) = _define_func(names, module, func, settings); + let (_ts, bounds) = _define_func(module, func, settings); bounds } fn _define_func( - names: &Names, module: &witx::Module, func: &witx::InterfaceFunc, settings: &CodegenSettings, ) -> (TokenStream, Vec) { - let rt = names.runtime_mod(); - let ident = names.func(&func.name); + let ident = names::func(&func.name); let (wasm_params, wasm_results) = func.wasm_signature(); let param_names = (0..wasm_params.len()) .map(|i| Ident::new(&format!("arg{}", i), Span::call_site())) .collect::>(); let abi_params = wasm_params.iter().zip(¶m_names).map(|(arg, name)| { - let wasm = names.wasm_type(*arg); + let wasm = names::wasm_type(*arg); quote!(#name : #wasm) }); let abi_ret = match wasm_results.len() { 0 => quote!(()), 1 => { - let ty = names.wasm_type(wasm_results[0]); + let ty = names::wasm_type(wasm_results[0]); quote!(#ty) } _ => unimplemented!(), }; let mut body = TokenStream::new(); - let mut bounds = vec![names.trait_name(&module.name)]; + let mut bounds = vec![names::trait_name(&module.name)]; func.call_interface( &module.name, &mut Rust { @@ -64,8 +60,6 @@ fn _define_func( params: ¶m_names, block_storage: Vec::new(), blocks: Vec::new(), - rt: &rt, - names, module, funcname: func.name.as_str(), settings, @@ -76,8 +70,8 @@ fn _define_func( let mod_name = &module.name.as_str(); let func_name = &func.name.as_str(); let mk_span = quote!( - let _span = #rt::tracing::span!( - #rt::tracing::Level::TRACE, + let _span = wiggle::tracing::span!( + wiggle::tracing::Level::TRACE, "wiggle abi", module = #mod_name, function = #func_name @@ -99,9 +93,9 @@ fn _define_func( #[allow(unreachable_code)] // deals with warnings in noreturn functions pub fn #ident( ctx: &mut (impl #(#bounds)+*), - memory: &dyn #rt::GuestMemory, + memory: &dyn wiggle::GuestMemory, #(#abi_params),* - ) -> #rt::anyhow::Result<#abi_ret> { + ) -> wiggle::anyhow::Result<#abi_ret> { use std::convert::TryFrom as _; #traced_body } @@ -111,7 +105,7 @@ fn _define_func( } else { let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) { quote!( - use #rt::tracing::Instrument as _; + use wiggle::tracing::Instrument as _; #mk_span async move { #body @@ -129,9 +123,9 @@ fn _define_func( #[allow(unreachable_code)] // deals with warnings in noreturn functions pub fn #ident<'a>( ctx: &'a mut (impl #(#bounds)+*), - memory: &'a dyn #rt::GuestMemory, + memory: &'a dyn wiggle::GuestMemory, #(#abi_params),* - ) -> impl std::future::Future> + 'a { + ) -> impl std::future::Future> + 'a { use std::convert::TryFrom as _; #traced_body } @@ -146,8 +140,6 @@ struct Rust<'a> { params: &'a [Ident], block_storage: Vec, blocks: Vec, - rt: &'a TokenStream, - names: &'a Names, module: &'a witx::Module, funcname: &'a str, settings: &'a CodegenSettings, @@ -196,17 +188,16 @@ impl witx::Bindgen for Rust<'_> { operands: &mut Vec, results: &mut Vec, ) { - let rt = self.rt; let wrap_err = |location: &str| { let modulename = self.module.name.as_str(); let funcname = self.funcname; quote! { |e| { - #rt::GuestError::InFunc { + wiggle::GuestError::InFunc { modulename: #modulename, funcname: #funcname, location: #location, - err: Box::new(#rt::GuestError::from(e)), + err: Box::new(wiggle::GuestError::from(e)), } } } @@ -226,9 +217,9 @@ impl witx::Bindgen for Rust<'_> { Instruction::PointerFromI32 { ty } | Instruction::ConstPointerFromI32 { ty } => { let val = operands.pop().unwrap(); - let pointee_type = self.names.type_ref(ty, anon_lifetime()); + let pointee_type = names::type_ref(ty, anon_lifetime()); results.push(quote! { - #rt::GuestPtr::<#pointee_type>::new(memory, #val as u32) + wiggle::GuestPtr::<#pointee_type>::new(memory, #val as u32) }); } @@ -238,12 +229,12 @@ impl witx::Bindgen for Rust<'_> { let ty = match &**ty.type_() { witx::Type::Builtin(witx::BuiltinType::Char) => quote!(str), _ => { - let ty = self.names.type_ref(ty, anon_lifetime()); + let ty = names::type_ref(ty, anon_lifetime()); quote!([#ty]) } }; results.push(quote! { - #rt::GuestPtr::<#ty>::new(memory, (#ptr as u32, #len as u32)); + wiggle::GuestPtr::<#ty>::new(memory, (#ptr as u32, #len as u32)); }) } @@ -252,7 +243,7 @@ impl witx::Bindgen for Rust<'_> { // out, and afterwards we call the function with those bindings. let mut args = Vec::new(); for (i, param) in func.params.iter().enumerate() { - let name = self.names.func_param(¶m.name); + let name = names::func_param(¶m.name); let val = &operands[i]; self.src.extend(quote!(let #name = #val;)); if passed_by_reference(param.tref.type_()) { @@ -271,21 +262,21 @@ impl witx::Bindgen for Rust<'_> { .params .iter() .map(|param| { - let name = self.names.func_param(¶m.name); + let name = names::func_param(¶m.name); if param.impls_display() { - quote!( #name = #rt::tracing::field::display(&#name) ) + quote!( #name = wiggle::tracing::field::display(&#name) ) } else { - quote!( #name = #rt::tracing::field::debug(&#name) ) + quote!( #name = wiggle::tracing::field::debug(&#name) ) } }) .collect::>(); self.src.extend(quote! { - #rt::tracing::event!(#rt::tracing::Level::TRACE, #(#args),*); + wiggle::tracing::event!(wiggle::tracing::Level::TRACE, #(#args),*); }); } - let trait_name = self.names.trait_name(&self.module.name); - let ident = self.names.func(&func.name); + let trait_name = names::trait_name(&self.module.name); + let ident = names::func(&func.name); if self.settings.get_async(&self.module, &func).is_sync() { self.src.extend(quote! { let ret = #trait_name::#ident(ctx, #(#args),*); @@ -301,9 +292,9 @@ impl witx::Bindgen for Rust<'_> { .enabled_for(self.module.name.as_str(), self.funcname) { self.src.extend(quote! { - #rt::tracing::event!( - #rt::tracing::Level::TRACE, - result = #rt::tracing::field::debug(&ret), + wiggle::tracing::event!( + wiggle::tracing::Level::TRACE, + result = wiggle::tracing::field::debug(&ret), ); }); } @@ -322,11 +313,12 @@ impl witx::Bindgen for Rust<'_> { Instruction::EnumLower { ty } => { let val = operands.pop().unwrap(); let val = match self.settings.errors.for_name(ty) { - Some(custom) => { - let method = self.names.user_error_conversion_method(&custom); + Some(ErrorType::User(custom)) => { + let method = names::user_error_conversion_method(&custom); self.bound(quote::format_ident!("UserErrorConversion")); quote!(UserErrorConversion::#method(ctx, #val)?) } + Some(ErrorType::Generated(_)) => quote!(#val.downcast()?), None => val, }; results.push(quote!(#val as i32)); @@ -336,10 +328,10 @@ impl witx::Bindgen for Rust<'_> { let err = self.blocks.pop().unwrap(); let ok = self.blocks.pop().unwrap(); let val = operands.pop().unwrap(); - let err_typename = self.names.type_ref(err_ty.unwrap(), anon_lifetime()); + let err_typename = names::type_ref(err_ty.unwrap(), anon_lifetime()); results.push(quote! { match #val { - Ok(e) => { #ok; <#err_typename as #rt::GuestErrorType>::success() as i32 } + Ok(e) => { #ok; <#err_typename as wiggle::GuestErrorType>::success() as i32 } Err(e) => { #err } } }); @@ -369,9 +361,9 @@ impl witx::Bindgen for Rust<'_> { let ptr = operands.pop().unwrap(); let val = operands.pop().unwrap(); let wrap_err = wrap_err(&format!("write {}", ty.name.as_str())); - let pointee_type = self.names.type_(&ty.name); + let pointee_type = names::type_(&ty.name); self.src.extend(quote! { - #rt::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) + wiggle::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) .write(#val) .map_err(#wrap_err)?; }); @@ -380,9 +372,9 @@ impl witx::Bindgen for Rust<'_> { Instruction::Load { ty } => { let ptr = operands.pop().unwrap(); let wrap_err = wrap_err(&format!("read {}", ty.name.as_str())); - let pointee_type = self.names.type_(&ty.name); + let pointee_type = names::type_(&ty.name); results.push(quote! { - #rt::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) + wiggle::GuestPtr::<#pointee_type>::new(memory, #ptr as u32) .read() .map_err(#wrap_err)? }); @@ -390,7 +382,7 @@ impl witx::Bindgen for Rust<'_> { Instruction::HandleFromI32 { ty } => { let val = operands.pop().unwrap(); - let ty = self.names.type_(&ty.name); + let ty = names::type_(&ty.name); results.push(quote!(#ty::from(#val))); } @@ -418,7 +410,7 @@ impl witx::Bindgen for Rust<'_> { Instruction::EnumLift { ty } | Instruction::BitflagsFromI64 { ty } | Instruction::BitflagsFromI32 { ty } => { - let ty = self.names.type_(&ty.name); + let ty = names::type_(&ty.name); try_from(quote!(#ty)) } diff --git a/crates/wiggle/generate/src/lib.rs b/crates/wiggle/generate/src/lib.rs index 9b5cea7579ce..77bbe6b005c9 100644 --- a/crates/wiggle/generate/src/lib.rs +++ b/crates/wiggle/generate/src/lib.rs @@ -3,7 +3,7 @@ pub mod config; mod funcs; mod lifetimes; mod module_trait; -mod names; +pub mod names; mod types; pub mod wasmtime; @@ -12,19 +12,16 @@ use lifetimes::anon_lifetime; use proc_macro2::{Literal, TokenStream}; use quote::quote; -pub use codegen_settings::{CodegenSettings, UserErrorType}; +pub use codegen_settings::{CodegenSettings, ErrorType, UserErrorType}; pub use config::{Config, WasmtimeConfig}; pub use funcs::define_func; pub use module_trait::define_module_trait; -pub use names::Names; pub use types::define_datatype; -pub fn generate(doc: &witx::Document, names: &Names, settings: &CodegenSettings) -> TokenStream { - // TODO at some point config should grow more ability to configure name - // overrides. - let rt = names.runtime_mod(); - - let types = doc.typenames().map(|t| define_datatype(&names, &t)); +pub fn generate(doc: &witx::Document, settings: &CodegenSettings) -> TokenStream { + let types = doc + .typenames() + .map(|t| define_datatype(&t, settings.errors.for_name(&t))); let constants = doc.constants().map(|c| { let name = quote::format_ident!( @@ -32,18 +29,24 @@ pub fn generate(doc: &witx::Document, names: &Names, settings: &CodegenSettings) c.ty.as_str().to_shouty_snake_case(), c.name.as_str().to_shouty_snake_case() ); - let ty = names.type_(&c.ty); + let ty = names::type_(&c.ty); let value = Literal::u64_unsuffixed(c.value); quote! { pub const #name: #ty = #value; } }); - let user_error_methods = settings.errors.iter().map(|errtype| { - let abi_typename = names.type_ref(&errtype.abi_type(), anon_lifetime()); - let user_typename = errtype.typename(); - let methodname = names.user_error_conversion_method(&errtype); - quote!(fn #methodname(&mut self, e: super::#user_typename) -> #rt::anyhow::Result<#abi_typename>;) + let user_error_methods = settings.errors.iter().filter_map(|errtype| match errtype { + ErrorType::User(errtype) => { + let abi_typename = names::type_ref(&errtype.abi_type(), anon_lifetime()); + let user_typename = errtype.typename(); + let methodname = names::user_error_conversion_method(&errtype); + Some(quote! { + fn #methodname(&mut self, e: super::#user_typename) + -> wiggle::anyhow::Result<#abi_typename>; + }) + } + ErrorType::Generated(_) => None, }); let user_error_conversion = quote! { pub trait UserErrorConversion { @@ -51,13 +54,11 @@ pub fn generate(doc: &witx::Document, names: &Names, settings: &CodegenSettings) } }; let modules = doc.modules().map(|module| { - let modname = names.module(&module.name); - let fs = module - .funcs() - .map(|f| define_func(&names, &module, &f, &settings)); - let modtrait = define_module_trait(&names, &module, &settings); + let modname = names::module(&module.name); + let fs = module.funcs().map(|f| define_func(&module, &f, &settings)); + let modtrait = define_module_trait(&module, &settings); let wasmtime = if settings.wasmtime { - crate::wasmtime::link_module(&module, &names, None, &settings) + crate::wasmtime::link_module(&module, None, &settings) } else { quote! {} }; @@ -86,14 +87,13 @@ pub fn generate(doc: &witx::Document, names: &Names, settings: &CodegenSettings) ) } -pub fn generate_metadata(doc: &witx::Document, names: &Names) -> TokenStream { - let rt = names.runtime_mod(); +pub fn generate_metadata(doc: &witx::Document) -> TokenStream { let doc_text = &format!("{}", doc); quote! { pub mod metadata { pub const DOC_TEXT: &str = #doc_text; - pub fn document() -> #rt::witx::Document { - #rt::witx::parse(DOC_TEXT).unwrap() + pub fn document() -> wiggle::witx::Document { + wiggle::witx::parse(DOC_TEXT).unwrap() } } } diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index be9a47ad3525..c1dbfac21fc2 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -1,9 +1,9 @@ use proc_macro2::TokenStream; use quote::quote; -use crate::codegen_settings::CodegenSettings; +use crate::codegen_settings::{CodegenSettings, ErrorType}; use crate::lifetimes::{anon_lifetime, LifetimeExt}; -use crate::names::Names; +use crate::names; use witx::Module; pub fn passed_by_reference(ty: &witx::Type) -> bool { @@ -15,9 +15,8 @@ pub fn passed_by_reference(ty: &witx::Type) -> bool { } } -pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings) -> TokenStream { - let traitname = names.trait_name(&m.name); - let rt = names.runtime_mod(); +pub fn define_module_trait(m: &Module, settings: &CodegenSettings) -> TokenStream { + let traitname = names::trait_name(&m.name); let traitmethods = m.funcs().map(|f| { // Check if we're returning an entity anotated with a lifetime, // in which case, we'll need to annotate the function itself, and @@ -32,10 +31,10 @@ pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings } else { (anon_lifetime(), true) }; - let funcname = names.func(&f.name); + let funcname = names::func(&f.name); let args = f.params.iter().map(|arg| { - let arg_name = names.func_param(&arg.name); - let arg_typename = names.type_ref(&arg.tref, lifetime.clone()); + let arg_name = names::func_param(&arg.name); + let arg_typename = names::type_ref(&arg.tref, lifetime.clone()); let arg_type = if passed_by_reference(&*arg.tref.type_()) { quote!(&#arg_typename) } else { @@ -45,7 +44,7 @@ pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings }); let result = match f.results.len() { - 0 if f.noreturn => quote!(#rt::anyhow::Error), + 0 if f.noreturn => quote!(wiggle::anyhow::Error), 0 => quote!(()), 1 => { let (ok, err) = match &**f.results[0].tref.type_() { @@ -57,16 +56,17 @@ pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings }; let ok = match ok { - Some(ty) => names.type_ref(ty, lifetime.clone()), + Some(ty) => names::type_ref(ty, lifetime.clone()), None => quote!(()), }; let err = match err { Some(ty) => match settings.errors.for_abi_error(ty) { - Some(custom) => { + Some(ErrorType::User(custom)) => { let tn = custom.typename(); quote!(super::#tn) } - None => names.type_ref(ty, lifetime.clone()), + Some(ErrorType::Generated(g)) => g.typename(), + None => names::type_ref(ty, lifetime.clone()), }, None => quote!(()), }; @@ -89,7 +89,7 @@ pub fn define_module_trait(names: &Names, m: &Module, settings: &CodegenSettings }); quote! { - #[#rt::async_trait] + #[wiggle::async_trait] pub trait #traitname { #(#traitmethods)* } diff --git a/crates/wiggle/generate/src/names.rs b/crates/wiggle/generate/src/names.rs index 635ba735ad65..a32a35be5740 100644 --- a/crates/wiggle/generate/src/names.rs +++ b/crates/wiggle/generate/src/names.rs @@ -6,204 +6,186 @@ use witx::{BuiltinType, Id, Type, TypeRef, WasmType}; use crate::{lifetimes::LifetimeExt, UserErrorType}; -pub struct Names { - runtime_mod: TokenStream, +pub fn type_(id: &Id) -> Ident { + escape_id(id, NamingConvention::CamelCase) } -impl Names { - pub fn new(runtime_mod: TokenStream) -> Names { - Names { runtime_mod } - } - - pub fn runtime_mod(&self) -> TokenStream { - self.runtime_mod.clone() +pub fn builtin_type(b: BuiltinType) -> TokenStream { + match b { + BuiltinType::U8 { .. } => quote!(u8), + BuiltinType::U16 => quote!(u16), + BuiltinType::U32 { .. } => quote!(u32), + BuiltinType::U64 => quote!(u64), + BuiltinType::S8 => quote!(i8), + BuiltinType::S16 => quote!(i16), + BuiltinType::S32 => quote!(i32), + BuiltinType::S64 => quote!(i64), + BuiltinType::F32 => quote!(f32), + BuiltinType::F64 => quote!(f64), + BuiltinType::Char => quote!(char), } +} - pub fn type_(&self, id: &Id) -> TokenStream { - let ident = escape_id(id, NamingConvention::CamelCase); - quote!(#ident) - } - - pub fn builtin_type(&self, b: BuiltinType) -> TokenStream { - match b { - BuiltinType::U8 { .. } => quote!(u8), - BuiltinType::U16 => quote!(u16), - BuiltinType::U32 { .. } => quote!(u32), - BuiltinType::U64 => quote!(u64), - BuiltinType::S8 => quote!(i8), - BuiltinType::S16 => quote!(i16), - BuiltinType::S32 => quote!(i32), - BuiltinType::S64 => quote!(i64), - BuiltinType::F32 => quote!(f32), - BuiltinType::F64 => quote!(f64), - BuiltinType::Char => quote!(char), - } +pub fn wasm_type(ty: WasmType) -> TokenStream { + match ty { + WasmType::I32 => quote!(i32), + WasmType::I64 => quote!(i64), + WasmType::F32 => quote!(f32), + WasmType::F64 => quote!(f64), } +} - pub fn wasm_type(&self, ty: WasmType) -> TokenStream { - match ty { - WasmType::I32 => quote!(i32), - WasmType::I64 => quote!(i64), - WasmType::F32 => quote!(f32), - WasmType::F64 => quote!(f64), +pub fn type_ref(tref: &TypeRef, lifetime: TokenStream) -> TokenStream { + match tref { + TypeRef::Name(nt) => { + let ident = type_(&nt.name); + if nt.tref.needs_lifetime() { + quote!(#ident<#lifetime>) + } else { + quote!(#ident) + } } - } - - pub fn type_ref(&self, tref: &TypeRef, lifetime: TokenStream) -> TokenStream { - match tref { - TypeRef::Name(nt) => { - let ident = self.type_(&nt.name); - if nt.tref.needs_lifetime() { - quote!(#ident<#lifetime>) - } else { - quote!(#ident) - } + TypeRef::Value(ty) => match &**ty { + Type::Builtin(builtin) => builtin_type(*builtin), + Type::Pointer(pointee) | Type::ConstPointer(pointee) => { + let pointee_type = type_ref(&pointee, lifetime.clone()); + quote!(wiggle::GuestPtr<#lifetime, #pointee_type>) } - TypeRef::Value(ty) => match &**ty { - Type::Builtin(builtin) => self.builtin_type(*builtin), - Type::Pointer(pointee) | Type::ConstPointer(pointee) => { - let rt = self.runtime_mod(); - let pointee_type = self.type_ref(&pointee, lifetime.clone()); - quote!(#rt::GuestPtr<#lifetime, #pointee_type>) + Type::List(pointee) => match &**pointee.type_() { + Type::Builtin(BuiltinType::Char) => { + quote!(wiggle::GuestPtr<#lifetime, str>) } - Type::List(pointee) => match &**pointee.type_() { - Type::Builtin(BuiltinType::Char) => { - let rt = self.runtime_mod(); - quote!(#rt::GuestPtr<#lifetime, str>) - } - _ => { - let rt = self.runtime_mod(); - let pointee_type = self.type_ref(&pointee, lifetime.clone()); - quote!(#rt::GuestPtr<#lifetime, [#pointee_type]>) - } - }, - Type::Variant(v) => match v.as_expected() { - Some((ok, err)) => { - let ok = match ok { - Some(ty) => self.type_ref(ty, lifetime.clone()), - None => quote!(()), - }; - let err = match err { - Some(ty) => self.type_ref(ty, lifetime.clone()), - None => quote!(()), - }; - quote!(Result<#ok, #err>) - } - None => unimplemented!("anonymous variant ref {:?}", tref), - }, - Type::Record(r) if r.is_tuple() => { - let types = r - .members - .iter() - .map(|m| self.type_ref(&m.tref, lifetime.clone())) - .collect::>(); - quote!((#(#types,)*)) + _ => { + let pointee_type = type_ref(&pointee, lifetime.clone()); + quote!(wiggle::GuestPtr<#lifetime, [#pointee_type]>) } - _ => unimplemented!("anonymous type ref {:?}", tref), }, - } + Type::Variant(v) => match v.as_expected() { + Some((ok, err)) => { + let ok = match ok { + Some(ty) => type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + let err = match err { + Some(ty) => type_ref(ty, lifetime.clone()), + None => quote!(()), + }; + quote!(Result<#ok, #err>) + } + None => unimplemented!("anonymous variant ref {:?}", tref), + }, + Type::Record(r) if r.is_tuple() => { + let types = r + .members + .iter() + .map(|m| type_ref(&m.tref, lifetime.clone())) + .collect::>(); + quote!((#(#types,)*)) + } + _ => unimplemented!("anonymous type ref {:?}", tref), + }, } +} - /// Convert an enum variant from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn enum_variant(&self, id: &Id) -> Ident { - handle_2big_enum_variant(id).unwrap_or_else(|| escape_id(id, NamingConvention::CamelCase)) - } +/// Convert an enum variant from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn enum_variant(id: &Id) -> Ident { + handle_2big_enum_variant(id).unwrap_or_else(|| escape_id(id, NamingConvention::CamelCase)) +} - pub fn flag_member(&self, id: &Id) -> Ident { - format_ident!("{}", id.as_str().to_shouty_snake_case()) - } +pub fn flag_member(id: &Id) -> Ident { + format_ident!("{}", id.as_str().to_shouty_snake_case()) +} - pub fn int_member(&self, id: &Id) -> Ident { - format_ident!("{}", id.as_str().to_shouty_snake_case()) - } +pub fn int_member(id: &Id) -> Ident { + format_ident!("{}", id.as_str().to_shouty_snake_case()) +} - /// Convert a struct member from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn struct_member(&self, id: &Id) -> Ident { - escape_id(id, NamingConvention::SnakeCase) - } +/// Convert a struct member from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn struct_member(id: &Id) -> Ident { + escape_id(id, NamingConvention::SnakeCase) +} - /// Convert a module name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn module(&self, id: &Id) -> Ident { - escape_id(id, NamingConvention::SnakeCase) - } +/// Convert a module name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn module(id: &Id) -> Ident { + escape_id(id, NamingConvention::SnakeCase) +} - /// Convert a trait name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn trait_name(&self, id: &Id) -> Ident { - escape_id(id, NamingConvention::CamelCase) - } +/// Convert a trait name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn trait_name(id: &Id) -> Ident { + escape_id(id, NamingConvention::CamelCase) +} - /// Convert a function name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn func(&self, id: &Id) -> Ident { - escape_id(id, NamingConvention::SnakeCase) - } +/// Convert a function name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn func(id: &Id) -> Ident { + escape_id(id, NamingConvention::SnakeCase) +} - /// Convert a parameter name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. - /// - /// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html - /// [witx]: https://docs.rs/witx/*/witx/struct.Id.html - pub fn func_param(&self, id: &Id) -> Ident { - escape_id(id, NamingConvention::SnakeCase) - } +/// Convert a parameter name from its [`Id`][witx] name to its Rust [`Ident`][id] representation. +/// +/// [id]: https://docs.rs/proc-macro2/*/proc_macro2/struct.Ident.html +/// [witx]: https://docs.rs/witx/*/witx/struct.Id.html +pub fn func_param(id: &Id) -> Ident { + escape_id(id, NamingConvention::SnakeCase) +} - /// For when you need a {name}_ptr binding for passing a value by reference: - pub fn func_ptr_binding(&self, id: &Id) -> Ident { - format_ident!("{}_ptr", id.as_str().to_snake_case()) - } +/// For when you need a {name}_ptr binding for passing a value by reference: +pub fn func_ptr_binding(id: &Id) -> Ident { + format_ident!("{}_ptr", id.as_str().to_snake_case()) +} - /// For when you need a {name}_len binding for passing an array: - pub fn func_len_binding(&self, id: &Id) -> Ident { - format_ident!("{}_len", id.as_str().to_snake_case()) - } +/// For when you need a {name}_len binding for passing an array: +pub fn func_len_binding(id: &Id) -> Ident { + format_ident!("{}_len", id.as_str().to_snake_case()) +} - fn builtin_name(b: &BuiltinType) -> &'static str { - match b { - BuiltinType::U8 { .. } => "u8", - BuiltinType::U16 => "u16", - BuiltinType::U32 { .. } => "u32", - BuiltinType::U64 => "u64", - BuiltinType::S8 => "i8", - BuiltinType::S16 => "i16", - BuiltinType::S32 => "i32", - BuiltinType::S64 => "i64", - BuiltinType::F32 => "f32", - BuiltinType::F64 => "f64", - BuiltinType::Char => "char", - } +fn builtin_name(b: &BuiltinType) -> &'static str { + match b { + BuiltinType::U8 { .. } => "u8", + BuiltinType::U16 => "u16", + BuiltinType::U32 { .. } => "u32", + BuiltinType::U64 => "u64", + BuiltinType::S8 => "i8", + BuiltinType::S16 => "i16", + BuiltinType::S32 => "i32", + BuiltinType::S64 => "i64", + BuiltinType::F32 => "f32", + BuiltinType::F64 => "f64", + BuiltinType::Char => "char", } +} - fn snake_typename(tref: &TypeRef) -> String { - match tref { - TypeRef::Name(nt) => nt.name.as_str().to_snake_case(), - TypeRef::Value(ty) => match &**ty { - Type::Builtin(b) => Self::builtin_name(&b).to_owned(), - _ => panic!("unexpected anonymous type: {:?}", ty), - }, - } +fn snake_typename(tref: &TypeRef) -> String { + match tref { + TypeRef::Name(nt) => nt.name.as_str().to_snake_case(), + TypeRef::Value(ty) => match &**ty { + Type::Builtin(b) => builtin_name(&b).to_owned(), + _ => panic!("unexpected anonymous type: {:?}", ty), + }, } +} - pub fn user_error_conversion_method(&self, user_type: &UserErrorType) -> Ident { - let abi_type = Self::snake_typename(&user_type.abi_type()); - format_ident!( - "{}_from_{}", - abi_type, - user_type.method_fragment().to_snake_case() - ) - } +pub fn user_error_conversion_method(user_type: &UserErrorType) -> Ident { + let abi_type = snake_typename(&user_type.abi_type()); + format_ident!( + "{}_from_{}", + abi_type, + user_type.method_fragment().to_snake_case() + ) } /// Identifier escaping utilities. diff --git a/crates/wiggle/generate/src/types/error.rs b/crates/wiggle/generate/src/types/error.rs new file mode 100644 index 000000000000..9f0d51f7a031 --- /dev/null +++ b/crates/wiggle/generate/src/types/error.rs @@ -0,0 +1,53 @@ +use crate::codegen_settings::TrappableErrorType; +use crate::names; + +use proc_macro2::TokenStream; +use quote::quote; + +pub(super) fn define_error( + name: &witx::Id, + _v: &witx::Variant, + e: &TrappableErrorType, +) -> TokenStream { + let abi_error = names::type_(name); + let rich_error = e.typename(); + + quote! { + #[derive(Debug)] + pub struct #rich_error { + inner: anyhow::Error, + } + + impl std::fmt::Display for #rich_error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } + } + impl std::error::Error for #rich_error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.inner.source() + } + } + + impl #rich_error { + pub fn trap(inner: anyhow::Error) -> #rich_error { + Self { inner } + } + pub fn downcast(self) -> Result<#abi_error, anyhow::Error> { + self.inner.downcast() + } + pub fn downcast_ref(&self) -> Option<&#abi_error> { + self.inner.downcast_ref() + } + pub fn context(self, s: impl Into) -> Self { + Self { inner: self.inner.context(s.into()) } + } + } + + impl From<#abi_error> for #rich_error { + fn from(abi: #abi_error) -> #rich_error { + #rich_error { inner: anyhow::Error::from(abi) } + } + } + } +} diff --git a/crates/wiggle/generate/src/types/flags.rs b/crates/wiggle/generate/src/types/flags.rs index 34e35eeb2dd2..a55d3f00294a 100644 --- a/crates/wiggle/generate/src/types/flags.rs +++ b/crates/wiggle/generate/src/types/flags.rs @@ -1,30 +1,28 @@ -use crate::names::Names; +use crate::names; use proc_macro2::{Literal, TokenStream}; use quote::quote; pub(super) fn define_flags( - names: &Names, name: &witx::Id, repr: witx::IntRepr, record: &witx::RecordDatatype, ) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(&name); - let abi_repr = names.wasm_type(repr.into()); + let ident = names::type_(&name); + let abi_repr = names::wasm_type(repr.into()); let repr = super::int_repr_tokens(repr); let mut names_ = vec![]; let mut values_ = vec![]; for (i, member) in record.members.iter().enumerate() { - let name = names.flag_member(&member.name); + let name = names::flag_member(&member.name); let value_token = Literal::usize_unsuffixed(1 << i); names_.push(name); values_.push(value_token); } quote! { - #rt::bitflags::bitflags! { + wiggle::bitflags::bitflags! { pub struct #ident: #repr { #(const #names_ = #values_;)* } @@ -43,10 +41,10 @@ pub(super) fn define_flags( } impl TryFrom<#repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #repr) -> Result { + type Error = wiggle::GuestError; + fn try_from(value: #repr) -> Result { if #repr::from(!#ident::all()) & value != 0 { - Err(#rt::GuestError::InvalidFlagValue(stringify!(#ident))) + Err(wiggle::GuestError::InvalidFlagValue(stringify!(#ident))) } else { Ok(#ident { bits: value }) } @@ -54,8 +52,8 @@ pub(super) fn define_flags( } impl TryFrom<#abi_repr> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_repr) -> Result { + type Error = wiggle::GuestError; + fn try_from(value: #abi_repr) -> Result { #ident::try_from(#repr::try_from(value)?) } } @@ -66,7 +64,7 @@ pub(super) fn define_flags( } } - impl<'a> #rt::GuestType<'a> for #ident { + impl<'a> wiggle::GuestType<'a> for #ident { fn guest_size() -> u32 { #repr::guest_size() } @@ -75,14 +73,14 @@ pub(super) fn define_flags( #repr::guest_align() } - fn read(location: &#rt::GuestPtr<#ident>) -> Result<#ident, #rt::GuestError> { + fn read(location: &wiggle::GuestPtr<#ident>) -> Result<#ident, wiggle::GuestError> { use std::convert::TryFrom; let reprval = #repr::read(&location.cast())?; let value = #ident::try_from(reprval)?; Ok(value) } - fn write(location: &#rt::GuestPtr<'_, #ident>, val: Self) -> Result<(), #rt::GuestError> { + fn write(location: &wiggle::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle::GuestError> { let val: #repr = #repr::from(val); #repr::write(&location.cast(), val) } diff --git a/crates/wiggle/generate/src/types/handle.rs b/crates/wiggle/generate/src/types/handle.rs index 7eb954c91e97..cf126484e51e 100644 --- a/crates/wiggle/generate/src/types/handle.rs +++ b/crates/wiggle/generate/src/types/handle.rs @@ -1,16 +1,11 @@ -use crate::names::Names; +use crate::names; use proc_macro2::TokenStream; use quote::quote; use witx::Layout; -pub(super) fn define_handle( - names: &Names, - name: &witx::Id, - h: &witx::HandleDatatype, -) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(name); +pub(super) fn define_handle(name: &witx::Id, h: &witx::HandleDatatype) -> TokenStream { + let ident = names::type_(name); let size = h.mem_size_align().size as u32; let align = h.mem_size_align().align as usize; quote! { @@ -53,7 +48,7 @@ pub(super) fn define_handle( } } - impl<'a> #rt::GuestType<'a> for #ident { + impl<'a> wiggle::GuestType<'a> for #ident { fn guest_size() -> u32 { #size } @@ -62,11 +57,11 @@ pub(super) fn define_handle( #align } - fn read(location: &#rt::GuestPtr<'a, #ident>) -> Result<#ident, #rt::GuestError> { + fn read(location: &wiggle::GuestPtr<'a, #ident>) -> Result<#ident, wiggle::GuestError> { Ok(#ident(u32::read(&location.cast())?)) } - fn write(location: &#rt::GuestPtr<'_, Self>, val: Self) -> Result<(), #rt::GuestError> { + fn write(location: &wiggle::GuestPtr<'_, Self>, val: Self) -> Result<(), wiggle::GuestError> { u32::write(&location.cast(), val.0) } } diff --git a/crates/wiggle/generate/src/types/mod.rs b/crates/wiggle/generate/src/types/mod.rs index 255c2ce7b223..2820db9fbd29 100644 --- a/crates/wiggle/generate/src/types/mod.rs +++ b/crates/wiggle/generate/src/types/mod.rs @@ -1,42 +1,49 @@ // mod r#enum; +mod error; mod flags; mod handle; mod record; mod variant; +use crate::codegen_settings::ErrorType; use crate::lifetimes::LifetimeExt; -use crate::names::Names; +use crate::names; use proc_macro2::TokenStream; use quote::quote; -pub fn define_datatype(names: &Names, namedtype: &witx::NamedType) -> TokenStream { +pub fn define_datatype(namedtype: &witx::NamedType, error: Option<&ErrorType>) -> TokenStream { match &namedtype.tref { - witx::TypeRef::Name(alias_to) => define_alias(names, &namedtype.name, &alias_to), + witx::TypeRef::Name(alias_to) => define_alias(&namedtype.name, &alias_to), witx::TypeRef::Value(v) => match &**v { witx::Type::Record(r) => match r.bitflags_repr() { - Some(repr) => flags::define_flags(names, &namedtype.name, repr, &r), - None => record::define_struct(names, &namedtype.name, &r), + Some(repr) => flags::define_flags(&namedtype.name, repr, &r), + None => record::define_struct(&namedtype.name, &r), }, - witx::Type::Variant(v) => variant::define_variant(names, &namedtype.name, &v), - witx::Type::Handle(h) => handle::define_handle(names, &namedtype.name, &h), - witx::Type::Builtin(b) => define_builtin(names, &namedtype.name, *b), + witx::Type::Variant(v) => match error { + Some(ErrorType::Generated(error)) => { + let d = variant::define_variant(&namedtype.name, &v, true); + let e = error::define_error(&namedtype.name, &v, error); + quote!( #d #e ) + } + _ => variant::define_variant(&namedtype.name, &v, false), + }, + witx::Type::Handle(h) => handle::define_handle(&namedtype.name, &h), + witx::Type::Builtin(b) => define_builtin(&namedtype.name, *b), witx::Type::Pointer(p) => { - let rt = names.runtime_mod(); - define_witx_pointer(names, &namedtype.name, quote!(#rt::GuestPtr), p) + define_witx_pointer(&namedtype.name, quote!(wiggle::GuestPtr), p) } witx::Type::ConstPointer(p) => { - let rt = names.runtime_mod(); - define_witx_pointer(names, &namedtype.name, quote!(#rt::GuestPtr), p) + define_witx_pointer(&namedtype.name, quote!(wiggle::GuestPtr), p) } - witx::Type::List(arr) => define_witx_list(names, &namedtype.name, &arr), + witx::Type::List(arr) => define_witx_list(&namedtype.name, &arr), }, } } -fn define_alias(names: &Names, name: &witx::Id, to: &witx::NamedType) -> TokenStream { - let ident = names.type_(name); - let rhs = names.type_(&to.name); +fn define_alias(name: &witx::Id, to: &witx::NamedType) -> TokenStream { + let ident = names::type_(name); + let rhs = names::type_(&to.name); if to.tref.needs_lifetime() { quote!(pub type #ident<'a> = #rhs<'a>;) } else { @@ -44,29 +51,27 @@ fn define_alias(names: &Names, name: &witx::Id, to: &witx::NamedType) -> TokenSt } } -fn define_builtin(names: &Names, name: &witx::Id, builtin: witx::BuiltinType) -> TokenStream { - let ident = names.type_(name); - let built = names.builtin_type(builtin); +fn define_builtin(name: &witx::Id, builtin: witx::BuiltinType) -> TokenStream { + let ident = names::type_(name); + let built = names::builtin_type(builtin); quote!(pub type #ident = #built;) } fn define_witx_pointer( - names: &Names, name: &witx::Id, pointer_type: TokenStream, pointee: &witx::TypeRef, ) -> TokenStream { - let ident = names.type_(name); - let pointee_type = names.type_ref(pointee, quote!('a)); + let ident = names::type_(name); + let pointee_type = names::type_ref(pointee, quote!('a)); quote!(pub type #ident<'a> = #pointer_type<'a, #pointee_type>;) } -fn define_witx_list(names: &Names, name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { - let ident = names.type_(name); - let rt = names.runtime_mod(); - let pointee_type = names.type_ref(arr_raw, quote!('a)); - quote!(pub type #ident<'a> = #rt::GuestPtr<'a, [#pointee_type]>;) +fn define_witx_list(name: &witx::Id, arr_raw: &witx::TypeRef) -> TokenStream { + let ident = names::type_(name); + let pointee_type = names::type_ref(arr_raw, quote!('a)); + quote!(pub type #ident<'a> = wiggle::GuestPtr<'a, [#pointee_type]>;) } pub fn int_repr_tokens(int_repr: witx::IntRepr) -> TokenStream { diff --git a/crates/wiggle/generate/src/types/record.rs b/crates/wiggle/generate/src/types/record.rs index d50124d7eaed..a5b48d18204c 100644 --- a/crates/wiggle/generate/src/types/record.rs +++ b/crates/wiggle/generate/src/types/record.rs @@ -1,26 +1,21 @@ use crate::lifetimes::{anon_lifetime, LifetimeExt}; -use crate::names::Names; +use crate::names; use proc_macro2::TokenStream; use quote::quote; use witx::Layout; -pub(super) fn define_struct( - names: &Names, - name: &witx::Id, - s: &witx::RecordDatatype, -) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(name); +pub(super) fn define_struct(name: &witx::Id, s: &witx::RecordDatatype) -> TokenStream { + let ident = names::type_(name); let size = s.mem_size_align().size as u32; let align = s.mem_size_align().align as usize; - let member_names = s.members.iter().map(|m| names.struct_member(&m.name)); + let member_names = s.members.iter().map(|m| names::struct_member(&m.name)); let member_decls = s.members.iter().map(|m| { - let name = names.struct_member(&m.name); + let name = names::struct_member(&m.name); let type_ = match &m.tref { witx::TypeRef::Name(nt) => { - let tt = names.type_(&nt.name); + let tt = names::type_(&nt.name); if m.tref.needs_lifetime() { quote!(#tt<'a>) } else { @@ -28,10 +23,10 @@ pub(super) fn define_struct( } } witx::TypeRef::Value(ty) => match &**ty { - witx::Type::Builtin(builtin) => names.builtin_type(*builtin), + witx::Type::Builtin(builtin) => names::builtin_type(*builtin), witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { - let pointee_type = names.type_ref(&pointee, quote!('a)); - quote!(#rt::GuestPtr<'a, #pointee_type>) + let pointee_type = names::type_ref(&pointee, quote!('a)); + quote!(wiggle::GuestPtr<'a, #pointee_type>) } _ => unimplemented!("other anonymous struct members: {:?}", m.tref), }, @@ -40,27 +35,27 @@ pub(super) fn define_struct( }); let member_reads = s.member_layout().into_iter().map(|ml| { - let name = names.struct_member(&ml.member.name); + let name = names::struct_member(&ml.member.name); let offset = ml.offset as u32; let location = quote!(location.cast::().add(#offset)?.cast()); match &ml.member.tref { witx::TypeRef::Name(nt) => { - let type_ = names.type_(&nt.name); + let type_ = names::type_(&nt.name); quote! { - let #name = <#type_ as #rt::GuestType>::read(&#location)?; + let #name = <#type_ as wiggle::GuestType>::read(&#location)?; } } witx::TypeRef::Value(ty) => match &**ty { witx::Type::Builtin(builtin) => { - let type_ = names.builtin_type(*builtin); + let type_ = names::builtin_type(*builtin); quote! { - let #name = <#type_ as #rt::GuestType>::read(&#location)?; + let #name = <#type_ as wiggle::GuestType>::read(&#location)?; } } witx::Type::Pointer(pointee) | witx::Type::ConstPointer(pointee) => { - let pointee_type = names.type_ref(&pointee, anon_lifetime()); + let pointee_type = names::type_ref(&pointee, anon_lifetime()); quote! { - let #name = <#rt::GuestPtr::<#pointee_type> as #rt::GuestType>::read(&#location)?; + let #name = as wiggle::GuestType>::read(&#location)?; } } _ => unimplemented!("other anonymous struct members: {:?}", ty), @@ -69,10 +64,10 @@ pub(super) fn define_struct( }); let member_writes = s.member_layout().into_iter().map(|ml| { - let name = names.struct_member(&ml.member.name); + let name = names::struct_member(&ml.member.name); let offset = ml.offset as u32; quote! { - #rt::GuestType::write( + wiggle::GuestType::write( &location.cast::().add(#offset)?.cast(), val.#name, )?; @@ -91,7 +86,7 @@ pub(super) fn define_struct( #(#member_decls),* } - impl<'a> #rt::GuestType<'a> for #ident #struct_lifetime { + impl<'a> wiggle::GuestType<'a> for #ident #struct_lifetime { fn guest_size() -> u32 { #size } @@ -100,12 +95,12 @@ pub(super) fn define_struct( #align } - fn read(location: &#rt::GuestPtr<'a, Self>) -> Result { + fn read(location: &wiggle::GuestPtr<'a, Self>) -> Result { #(#member_reads)* Ok(#ident { #(#member_names),* }) } - fn write(location: &#rt::GuestPtr<'_, Self>, val: Self) -> Result<(), #rt::GuestError> { + fn write(location: &wiggle::GuestPtr<'_, Self>, val: Self) -> Result<(), wiggle::GuestError> { #(#member_writes)* Ok(()) } diff --git a/crates/wiggle/generate/src/types/variant.rs b/crates/wiggle/generate/src/types/variant.rs index 5624b43eb040..c16e3bc72d97 100644 --- a/crates/wiggle/generate/src/types/variant.rs +++ b/crates/wiggle/generate/src/types/variant.rs @@ -1,13 +1,16 @@ use crate::lifetimes::LifetimeExt; -use crate::names::Names; +use crate::names; use proc_macro2::{Literal, TokenStream}; use quote::quote; use witx::Layout; -pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) -> TokenStream { - let rt = names.runtime_mod(); - let ident = names.type_(name); +pub(super) fn define_variant( + name: &witx::Id, + v: &witx::Variant, + derive_std_error: bool, +) -> TokenStream { + let ident = names::type_(name); let size = v.mem_size_align().size as u32; let align = v.mem_size_align().align as usize; let contents_offset = v.payload_offset() as u32; @@ -16,9 +19,9 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) let tag_ty = super::int_repr_tokens(v.tag_repr); let variants = v.cases.iter().map(|c| { - let var_name = names.enum_variant(&c.name); + let var_name = names::enum_variant(&c.name); if let Some(tref) = &c.tref { - let var_type = names.type_ref(&tref, lifetime.clone()); + let var_type = names::type_ref(&tref, lifetime.clone()); quote!(#var_name(#var_type)) } else { quote!(#var_name) @@ -27,13 +30,13 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) let read_variant = v.cases.iter().enumerate().map(|(i, c)| { let i = Literal::usize_unsuffixed(i); - let variantname = names.enum_variant(&c.name); + let variantname = names::enum_variant(&c.name); if let Some(tref) = &c.tref { - let varianttype = names.type_ref(tref, lifetime.clone()); + let varianttype = names::type_ref(tref, lifetime.clone()); quote! { #i => { let variant_ptr = location.cast::().add(#contents_offset)?; - let variant_val = <#varianttype as #rt::GuestType>::read(&variant_ptr.cast())?; + let variant_val = <#varianttype as wiggle::GuestType>::read(&variant_ptr.cast())?; Ok(#ident::#variantname(variant_val)) } } @@ -43,17 +46,17 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) }); let write_variant = v.cases.iter().enumerate().map(|(i, c)| { - let variantname = names.enum_variant(&c.name); + let variantname = names::enum_variant(&c.name); let write_tag = quote! { location.cast().write(#i as #tag_ty)?; }; if let Some(tref) = &c.tref { - let varianttype = names.type_ref(tref, lifetime.clone()); + let varianttype = names::type_ref(tref, lifetime.clone()); quote! { #ident::#variantname(contents) => { #write_tag let variant_ptr = location.cast::().add(#contents_offset)?; - <#varianttype as #rt::GuestType>::write(&variant_ptr.cast(), contents)?; + <#varianttype as wiggle::GuestType>::write(&variant_ptr.cast(), contents)?; } } } else { @@ -68,26 +71,26 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) let mut extra_derive = quote!(); let enum_try_from = if v.cases.iter().all(|c| c.tref.is_none()) { let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| { - let variant_name = names.enum_variant(&c.name); + let variant_name = names::enum_variant(&c.name); let n = Literal::usize_unsuffixed(i); quote!(#n => Ok(#ident::#variant_name)) }); - let abi_ty = names.wasm_type(v.tag_repr.into()); + let abi_ty = names::wasm_type(v.tag_repr.into()); extra_derive = quote!(, Copy); quote! { impl TryFrom<#tag_ty> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #tag_ty) -> Result<#ident, #rt::GuestError> { + type Error = wiggle::GuestError; + fn try_from(value: #tag_ty) -> Result<#ident, wiggle::GuestError> { match value { #(#tryfrom_repr_cases),*, - _ => Err( #rt::GuestError::InvalidEnumValue(stringify!(#ident))), + _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))), } } } impl TryFrom<#abi_ty> for #ident { - type Error = #rt::GuestError; - fn try_from(value: #abi_ty) -> Result<#ident, #rt::GuestError> { + type Error = wiggle::GuestError; + fn try_from(value: #abi_ty) -> Result<#ident, wiggle::GuestError> { #ident::try_from(#tag_ty::try_from(value)?) } } @@ -98,7 +101,7 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) let enum_from = if v.cases.iter().all(|c| c.tref.is_none()) { let from_repr_cases = v.cases.iter().enumerate().map(|(i, c)| { - let variant_name = names.enum_variant(&c.name); + let variant_name = names::enum_variant(&c.name); let n = Literal::usize_unsuffixed(i); quote!(#ident::#variant_name => #n) }); @@ -121,16 +124,30 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) (quote!(), quote!(, PartialEq #extra_derive)) }; + let error_impls = if derive_std_error { + quote! { + impl std::fmt::Display for #ident { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } + } + impl std::error::Error for #ident {} + } + } else { + quote!() + }; + quote! { #[derive(Clone, Debug #extra_derive)] pub enum #ident #enum_lifetime { #(#variants),* } + #error_impls #enum_try_from #enum_from - impl<'a> #rt::GuestType<'a> for #ident #enum_lifetime { + impl<'a> wiggle::GuestType<'a> for #ident #enum_lifetime { fn guest_size() -> u32 { #size } @@ -139,19 +156,19 @@ pub(super) fn define_variant(names: &Names, name: &witx::Id, v: &witx::Variant) #align } - fn read(location: &#rt::GuestPtr<'a, Self>) - -> Result + fn read(location: &wiggle::GuestPtr<'a, Self>) + -> Result { let tag = location.cast::<#tag_ty>().read()?; match tag { #(#read_variant)* - _ => Err(#rt::GuestError::InvalidEnumValue(stringify!(#ident))), + _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))), } } - fn write(location: &#rt::GuestPtr<'_, Self>, val: Self) - -> Result<(), #rt::GuestError> + fn write(location: &wiggle::GuestPtr<'_, Self>, val: Self) + -> Result<(), wiggle::GuestError> { match val { #(#write_variant)* diff --git a/crates/wiggle/generate/src/wasmtime.rs b/crates/wiggle/generate/src/wasmtime.rs index 98c9f342ccad..fcab9da62ebc 100644 --- a/crates/wiggle/generate/src/wasmtime.rs +++ b/crates/wiggle/generate/src/wasmtime.rs @@ -1,17 +1,17 @@ use crate::config::Asyncness; use crate::funcs::func_bounds; -use crate::{CodegenSettings, Names}; +use crate::names; +use crate::CodegenSettings; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; use std::collections::HashSet; pub fn link_module( module: &witx::Module, - names: &Names, target_path: Option<&syn::Path>, settings: &CodegenSettings, ) -> TokenStream { - let module_ident = names.module(&module.name); + let module_ident = names::module(&module.name); let send_bound = if settings.async_.contains_async(module) { quote! { + Send, T: Send } @@ -23,8 +23,8 @@ pub fn link_module( let mut bounds = HashSet::new(); for f in module.funcs() { let asyncness = settings.async_.get(module.name.as_str(), f.name.as_str()); - bodies.push(generate_func(&module, &f, names, target_path, asyncness)); - let bound = func_bounds(names, module, &f, settings); + bodies.push(generate_func(&module, &f, target_path, asyncness)); + let bound = func_bounds(module, &f, settings); for b in bound { bounds.insert(b); } @@ -46,14 +46,12 @@ pub fn link_module( format_ident!("add_{}_to_linker", module_ident) }; - let rt = names.runtime_mod(); - quote! { /// Adds all instance items to the specified `Linker`. pub fn #func_name( - linker: &mut #rt::wasmtime_crate::Linker, + linker: &mut wiggle::wasmtime_crate::Linker, get_cx: impl Fn(&mut T) -> &mut U + Send + Sync + Copy + 'static, - ) -> #rt::anyhow::Result<()> + ) -> wiggle::anyhow::Result<()> where U: #ctx_bound #send_bound { @@ -66,17 +64,14 @@ pub fn link_module( fn generate_func( module: &witx::Module, func: &witx::InterfaceFunc, - names: &Names, target_path: Option<&syn::Path>, asyncness: Asyncness, ) -> TokenStream { - let rt = names.runtime_mod(); - let module_str = module.name.as_str(); - let module_ident = names.module(&module.name); + let module_ident = names::module(&module.name); let field_str = func.name.as_str(); - let field_ident = names.func(&func.name); + let field_ident = names::func(&func.name); let (params, results) = func.wasm_signature(); @@ -88,14 +83,14 @@ fn generate_func( .enumerate() .map(|(i, ty)| { let name = &arg_names[i]; - let wasm = names.wasm_type(*ty); + let wasm = names::wasm_type(*ty); quote! { #name: #wasm } }) .collect::>(); let ret_ty = match results.len() { 0 => quote!(()), - 1 => names.wasm_type(results[0]), + 1 => names::wasm_type(results[0]), _ => unimplemented!(), }; @@ -114,16 +109,16 @@ fn generate_func( let body = quote! { let export = caller.get_export("memory"); let (mem, ctx) = match &export { - Some(#rt::wasmtime_crate::Extern::Memory(m)) => { + Some(wiggle::wasmtime_crate::Extern::Memory(m)) => { let (mem, ctx) = m.data_and_store_mut(&mut caller); let ctx = get_cx(ctx); - (#rt::wasmtime::WasmtimeGuestMemory::new(mem), ctx) + (wiggle::wasmtime::WasmtimeGuestMemory::new(mem), ctx) } - Some(#rt::wasmtime_crate::Extern::SharedMemory(m)) => { + Some(wiggle::wasmtime_crate::Extern::SharedMemory(m)) => { let ctx = get_cx(caller.data_mut()); - (#rt::wasmtime::WasmtimeGuestMemory::shared(m.data()), ctx) + (wiggle::wasmtime::WasmtimeGuestMemory::shared(m.data()), ctx) } - _ => #rt::anyhow::bail!("missing required memory export"), + _ => wiggle::anyhow::bail!("missing required memory export"), }; Ok(<#ret_ty>::from(#abi_func(ctx, &mem #(, #arg_names)*) #await_ ?)) }; @@ -135,7 +130,7 @@ fn generate_func( linker.#wrapper( #module_str, #field_str, - move |mut caller: #rt::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| { + move |mut caller: wiggle::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| { Box::new(async move { #body }) }, )?; @@ -147,9 +142,9 @@ fn generate_func( linker.func_wrap( #module_str, #field_str, - move |mut caller: #rt::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| -> #rt::anyhow::Result<#ret_ty> { + move |mut caller: wiggle::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| -> wiggle::anyhow::Result<#ret_ty> { let result = async { #body }; - #rt::run_in_dummy_executor(result)? + wiggle::run_in_dummy_executor(result)? }, )?; } @@ -160,7 +155,7 @@ fn generate_func( linker.func_wrap( #module_str, #field_str, - move |mut caller: #rt::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| -> #rt::anyhow::Result<#ret_ty> { + move |mut caller: wiggle::wasmtime_crate::Caller<'_, T> #(, #arg_decls)*| -> wiggle::anyhow::Result<#ret_ty> { #body }, )?; diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index 11bbbfc01d8b..9ea2cf104857 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -39,6 +39,10 @@ use syn::parse_macro_input; /// `{ errno => YourErrnoType }`. This allows you to use the `UserErrorConversion` /// trait to map these rich errors into the flat witx type, or to terminate /// WebAssembly execution by trapping. +/// * Instead of requiring the user to define an error type, wiggle can +/// generate an error type for the user which has conversions to/from +/// the base type, and permits trapping, using the syntax +/// `errno => trappable AnErrorType`. /// * Optional: `async` takes a set of witx modules and functions which are /// made Rust `async` functions in the module trait. /// @@ -146,7 +150,6 @@ pub fn from_witx(args: TokenStream) -> TokenStream { let config = parse_macro_input!(args as wiggle_generate::Config); let doc = config.load_document(); - let names = wiggle_generate::Names::new(quote!(wiggle)); let settings = wiggle_generate::CodegenSettings::new( &config.errors, @@ -157,9 +160,9 @@ pub fn from_witx(args: TokenStream) -> TokenStream { ) .expect("validating codegen settings"); - let code = wiggle_generate::generate(&doc, &names, &settings); + let code = wiggle_generate::generate(&doc, &settings); let metadata = if cfg!(feature = "wiggle_metadata") { - wiggle_generate::generate_metadata(&doc, &names) + wiggle_generate::generate_metadata(&doc) } else { quote!() }; @@ -188,7 +191,6 @@ pub fn async_trait(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn wasmtime_integration(args: TokenStream) -> TokenStream { let config = parse_macro_input!(args as wiggle_generate::WasmtimeConfig); let doc = config.c.load_document(); - let names = wiggle_generate::Names::new(quote!(wiggle)); let settings = wiggle_generate::CodegenSettings::new( &config.c.errors, @@ -200,7 +202,7 @@ pub fn wasmtime_integration(args: TokenStream) -> TokenStream { .expect("validating codegen settings"); let modules = doc.modules().map(|module| { - wiggle_generate::wasmtime::link_module(&module, &names, Some(&config.target), &settings) + wiggle_generate::wasmtime::link_module(&module, Some(&config.target), &settings) }); quote!( #(#modules)* ).into() } diff --git a/crates/wiggle/tests/errors.rs b/crates/wiggle/tests/errors.rs index 737b7ea5a4b7..594937843003 100644 --- a/crates/wiggle/tests/errors.rs +++ b/crates/wiggle/tests/errors.rs @@ -24,34 +24,32 @@ mod convert_just_errno { (param $strike u32) (result $err (expected (error $errno))))) ", - errors: { errno => RichError }, + errors: { errno => trappable ErrnoT }, }); impl_errno!(types::Errno); - /// When the `errors` mapping in witx is non-empty, we need to impl the - /// types::UserErrorConversion trait that wiggle generates from that mapping. - impl<'a> types::UserErrorConversion for WasiCtx<'a> { - fn errno_from_rich_error(&mut self, e: RichError) -> Result { - // WasiCtx can collect a Vec log so we can test this. We're - // logging the Display impl that `thiserror::Error` provides us. - self.log.borrow_mut().push(e.to_string()); - // Then do the trivial mapping down to the flat enum. - match e { - RichError::InvalidArg { .. } => Ok(types::Errno::InvalidArg), - RichError::PicketLine { .. } => Ok(types::Errno::PicketLine), + impl From for types::ErrnoT { + fn from(rich: RichError) -> types::ErrnoT { + match rich { + RichError::InvalidArg(s) => { + types::ErrnoT::from(types::Errno::InvalidArg).context(s) + } + RichError::PicketLine(s) => { + types::ErrnoT::from(types::Errno::PicketLine).context(s) + } } } } impl<'a> one_error_conversion::OneErrorConversion for WasiCtx<'a> { - fn foo(&mut self, strike: u32) -> Result<(), RichError> { + fn foo(&mut self, strike: u32) -> Result<(), types::ErrnoT> { // We use the argument to this function to exercise all of the // possible error cases we could hit here match strike { 0 => Ok(()), - 1 => Err(RichError::PicketLine(format!("I'm not a scab"))), - _ => Err(RichError::InvalidArg(format!("out-of-bounds: {}", strike))), + 1 => Err(RichError::PicketLine(format!("I'm not a scab")))?, + _ => Err(RichError::InvalidArg(format!("out-of-bounds: {}", strike)))?, } } } @@ -78,11 +76,6 @@ mod convert_just_errno { types::Errno::PicketLine as i32, "Expected return value for strike=1" ); - assert_eq!( - ctx.log.borrow_mut().pop().expect("one log entry"), - "Won't cross picket line: I'm not a scab", - "Expected log entry for strike=1", - ); // Second error case: let r2 = one_error_conversion::foo(&mut ctx, &host_memory, 2).unwrap(); @@ -91,11 +84,6 @@ mod convert_just_errno { types::Errno::InvalidArg as i32, "Expected return value for strike=2" ); - assert_eq!( - ctx.log.borrow_mut().pop().expect("one log entry"), - "Invalid argument: out-of-bounds: 2", - "Expected log entry for strike=2", - ); } }