Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wiggle: new error configuration for generating a "trappable error" #5276

Merged
merged 2 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions crates/wiggle/generate/src/codegen_settings.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::config::{AsyncConf, ErrorConf, TracingConf};
use crate::config::{AsyncConf, ErrorConf, ErrorConfField, TracingConf};
use anyhow::{anyhow, Error};
use proc_macro2::TokenStream;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use std::collections::HashMap;
use std::rc::Rc;
Expand Down Expand Up @@ -39,7 +39,7 @@ impl CodegenSettings {
}

pub struct ErrorTransform {
m: Vec<UserErrorType>,
m: Vec<ErrorType>,
}

impl ErrorTransform {
Expand All @@ -49,7 +49,13 @@ impl ErrorTransform {
pub fn new(conf: &ErrorConf, doc: &Document) -> Result<Self, Error> {
let mut richtype_identifiers = HashMap::new();
let m = conf.iter().map(|(ident, field)|
if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) {
match field {
ErrorConfField::Trappable(field) => if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) {
Ok(ErrorType::Generated(TrappableErrorType { abi_type, rich_type: field.rich_error.clone() }))
} else {
Err(anyhow!("No witx typename \"{}\" found", ident.to_string()))
},
ErrorConfField::User(field) => if let Some(abi_type) = doc.typename(&Id::new(ident.to_string())) {
if let Some(ident) = field.rich_error.get_ident() {
if let Some(prior_def) = richtype_identifiers.insert(ident.clone(), field.err_loc.clone())
{
Expand All @@ -58,11 +64,11 @@ impl ErrorTransform {
ident, prior_def
));
}
Ok(UserErrorType {
Ok(ErrorType::User(UserErrorType {
abi_type,
rich_type: field.rich_error.clone(),
method_fragment: ident.to_string()
})
}))
} else {
return Err(anyhow!(
"rich error type must be identifier for now - TODO add ability to provide a corresponding identifier: {:?}",
Expand All @@ -71,23 +77,52 @@ impl ErrorTransform {
}
}
else { Err(anyhow!("No witx typename \"{}\" found", ident.to_string())) }
}
).collect::<Result<Vec<_>, Error>>()?;
Ok(Self { m })
}

pub fn iter(&self) -> impl Iterator<Item = &UserErrorType> {
pub fn iter(&self) -> impl Iterator<Item = &ErrorType> {
self.m.iter()
}

pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&UserErrorType> {
pub fn for_abi_error(&self, tref: &TypeRef) -> Option<&ErrorType> {
match tref {
TypeRef::Name(nt) => self.for_name(nt),
TypeRef::Value { .. } => None,
}
}

pub fn for_name(&self, nt: &NamedType) -> Option<&UserErrorType> {
self.m.iter().find(|u| u.abi_type.name == nt.name)
pub fn for_name(&self, nt: &NamedType) -> Option<&ErrorType> {
self.m.iter().find(|e| e.abi_type().name == nt.name)
}
}

pub enum ErrorType {
User(UserErrorType),
Generated(TrappableErrorType),
}
impl ErrorType {
pub fn abi_type(&self) -> &NamedType {
match self {
Self::User(u) => &u.abi_type,
Self::Generated(r) => &r.abi_type,
}
}
}

pub struct TrappableErrorType {
abi_type: Rc<NamedType>,
rich_type: Ident,
}

impl TrappableErrorType {
pub fn abi_type(&self) -> TypeRef {
TypeRef::Name(self.abi_type.clone())
}
pub fn typename(&self) -> TokenStream {
let richtype = &self.rich_type;
quote!(#richtype)
}
}

Expand Down
78 changes: 59 additions & 19 deletions crates/wiggle/generate/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod kw {
syn::custom_keyword!(wasmtime);
syn::custom_keyword!(tracing);
syn::custom_keyword!(disable_for);
syn::custom_keyword!(trappable);
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -274,14 +275,14 @@ impl Parse for ErrorConf {
content.parse_terminated(Parse::parse)?;
let mut m = HashMap::new();
for i in items {
match m.insert(i.abi_error.clone(), i.clone()) {
match m.insert(i.abi_error().clone(), i.clone()) {
None => {}
Some(prev_def) => {
return Err(Error::new(
i.err_loc,
*i.err_loc(),
format!(
"duplicate definition of rich error type for {:?}: previously defined at {:?}",
i.abi_error, prev_def.err_loc,
i.abi_error(), prev_def.err_loc(),
),
))
}
Expand All @@ -291,14 +292,67 @@ impl Parse for ErrorConf {
}
}

#[derive(Debug, Clone)]
pub enum ErrorConfField {
Trappable(TrappableErrorConfField),
User(UserErrorConfField),
}
impl ErrorConfField {
pub fn abi_error(&self) -> &Ident {
match self {
Self::Trappable(t) => &t.abi_error,
Self::User(u) => &u.abi_error,
}
}
pub fn err_loc(&self) -> &Span {
match self {
Self::Trappable(t) => &t.err_loc,
Self::User(u) => &u.err_loc,
}
}
}

impl Parse for ErrorConfField {
fn parse(input: ParseStream) -> Result<Self> {
let err_loc = input.span();
let abi_error = input.parse::<Ident>()?;
let _arrow: Token![=>] = input.parse()?;

let lookahead = input.lookahead1();
if lookahead.peek(kw::trappable) {
let _ = input.parse::<kw::trappable>()?;
let rich_error = input.parse()?;
Ok(ErrorConfField::Trappable(TrappableErrorConfField {
abi_error,
rich_error,
err_loc,
}))
} else {
let rich_error = input.parse::<syn::Path>()?;
Ok(ErrorConfField::User(UserErrorConfField {
abi_error,
rich_error,
err_loc,
}))
}
}
}

#[derive(Clone, Debug)]
pub struct TrappableErrorConfField {
pub abi_error: Ident,
pub rich_error: Ident,
pub err_loc: Span,
}

#[derive(Clone)]
pub struct ErrorConfField {
pub struct UserErrorConfField {
pub abi_error: Ident,
pub rich_error: syn::Path,
pub err_loc: Span,
}

impl std::fmt::Debug for ErrorConfField {
impl std::fmt::Debug for UserErrorConfField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ErrorConfField")
.field("abi_error", &self.abi_error)
Expand All @@ -308,20 +362,6 @@ impl std::fmt::Debug for ErrorConfField {
}
}

impl Parse for ErrorConfField {
fn parse(input: ParseStream) -> Result<Self> {
let err_loc = input.span();
let abi_error = input.parse::<Ident>()?;
let _arrow: Token![=>] = input.parse()?;
let rich_error = input.parse::<syn::Path>()?;
Ok(ErrorConfField {
abi_error,
rich_error,
err_loc,
})
}
}

#[derive(Clone, Default, Debug)]
/// Modules and funcs that have async signatures
pub struct AsyncConf {
Expand Down
Loading