diff --git a/docs/docs/comment_dsl.mdx b/docs/docs/comment_dsl.mdx index c92bd77..ac90aea 100644 --- a/docs/docs/comment_dsl.mdx +++ b/docs/docs/comment_dsl.mdx @@ -104,6 +104,63 @@ foo = uint ; @newtype @custom_json Avoids generating and/or deriving json-related traits under the assumption that the user will supply their own implementation to be used in the generated library. +## @custom_serialize / @custom_deserialize + +```cddl +custom_bytes = bytes ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + +struct_with_custom_serialization = [ + custom_bytes, + field: bytes, ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + overridden: custom_bytes, ; @custom_serialize write_hex_string @custom_deserialize read_hex_string + tagged1: #6.9(custom_bytes), + tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str +] +``` + +This allows the overriding of serialization and/or deserialization for when a specific format must be maintained. This works even with primitives where _CDDL_CODEGEN_EXTERN_TYPE_ would require making a wrapper type to use. + +The string after `@custom_serialize`/`@custom_deserialize` will be directly called as a function in place of regular serialization/deserialization code. As such it must either be specified using fully qualified paths e.g. `@custom_serialize crate::utils::custom_serialize_function`, or post-generation it will need to be imported into the serialization code by hand e.g. adding `use crate::utils::custom_serialize_function;`. + +With `--preserve-encodings=true` the encoding variables must be passed in the order they are used in cddl-codegen with regular serialization. They are passed in as `Option<cbor_event::Sz>` for integers/tags, `LenEncoding` for lengths and `StringEncoding` for text/bytes. These are the same types that are stored in the generated `*Encoding` structs. The same must be returned for deserialization. 
When there are no encoding variables the deserialized value should be returned directly; otherwise a tuple of the value and its encoding variables should be returned. + +There are two ways to use this comment DSL: + +* Type level: e.g. `custom_bytes`. This will replace the (de)serialization everywhere you use this type. +* Field level: e.g. `struct_with_custom_serialization.field`. This will entirely replace the (de)serialization logic for the entire field, including other encoding operations like tags, `.cbor`, etc. + +Example function signatures for `--preserve-encodings=false` for `custom_serialize_bytes` / `custom_deserialize_bytes` above: + +```rust +pub fn custom_serialize_bytes<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer<W>, + bytes: &[u8], +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer<W>> + +pub fn custom_deserialize_bytes<R: std::io::BufRead + std::io::Seek>( + raw: &mut cbor_event::de::Deserializer<R>, +) -> Result<Vec<u8>, DeserializeError> +``` + +Example function signatures for `--preserve-encodings=true` for `write_tagged_uint_str` / `read_tagged_uint_str` above: + +```rust +pub fn write_tagged_uint_str<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer<W>, + uint: &u64, + tag_encoding: Option<cbor_event::Sz>, + text_encoding: Option<cbor_event::Sz>, +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer<W>> + +pub fn read_tagged_uint_str<R: std::io::BufRead + std::io::Seek>( + raw: &mut cbor_event::de::Deserializer<R>, +) -> Result<(u64, Option<cbor_event::Sz>, Option<cbor_event::Sz>), DeserializeError> +``` + +Note that as this is at the field-level it must handle the tag as well as the `uint`. + +For more examples see `tests/custom_serialization` (used in the `core` and `core_no_wasm` tests) and `tests/custom_serialization_preserve` (used in the `preserve-encodings` test). + ## _CDDL_CODEGEN_EXTERN_TYPE_ While not as a comment, this allows you to compose in hand-written structs into a cddl spec. 
diff --git a/src/comment_ast.rs b/src/comment_ast.rs index aef250a..9efbaf9 100644 --- a/src/comment_ast.rs +++ b/src/comment_ast.rs @@ -6,13 +6,15 @@ use nom::{ IResult, }; -#[derive(Default, Debug, PartialEq)] +#[derive(Clone, Default, Debug, PartialEq)] pub struct RuleMetadata { pub name: Option, pub is_newtype: bool, pub no_alias: bool, pub used_as_key: bool, pub custom_json: bool, + pub custom_serialize: Option, + pub custom_deserialize: Option, } pub fn merge_metadata(r1: &RuleMetadata, r2: &RuleMetadata) -> RuleMetadata { @@ -28,6 +30,29 @@ pub fn merge_metadata(r1: &RuleMetadata, r2: &RuleMetadata) -> RuleMetadata { no_alias: r1.no_alias || r2.no_alias, used_as_key: r1.used_as_key || r2.used_as_key, custom_json: r1.custom_json || r2.custom_json, + custom_serialize: match (r1.custom_serialize.as_ref(), r2.custom_serialize.as_ref()) { + (Some(val1), Some(val2)) => { + panic!( + "Key \"custom_serialize\" specified twice: {:?} {:?}", + val1, val2 + ) + } + (val @ Some(_), _) => val.cloned(), + (_, val) => val.cloned(), + }, + custom_deserialize: match ( + r1.custom_deserialize.as_ref(), + r2.custom_deserialize.as_ref(), + ) { + (Some(val1), Some(val2)) => { + panic!( + "Key \"custom_deserialize\" specified twice: {:?} {:?}", + val1, val2 + ) + } + (val @ Some(_), _) => val.cloned(), + (_, val) => val.cloned(), + }, }; merged.verify(); merged @@ -39,6 +64,8 @@ enum ParseResult { DontGenAlias, UsedAsKey, CustomJson, + CustomSerialize(String), + CustomDeserialize(String), } impl RuleMetadata { @@ -67,6 +94,32 @@ impl RuleMetadata { ParseResult::CustomJson => { base.custom_json = true; } + ParseResult::CustomSerialize(custom_serialize) => { + match base.custom_serialize.as_ref() { + Some(old) => { + panic!( + "Key \"custom_serialize\" specified twice: {:?} {:?}", + old, custom_serialize + ) + } + None => { + base.custom_serialize = Some(custom_serialize.to_string()); + } + } + } + ParseResult::CustomDeserialize(custom_deserialize) => { + match 
base.custom_deserialize.as_ref() { + Some(old) => { + panic!( + "Key \"custom_deserialize\" specified twice: {:?} {:?}", + old, custom_deserialize + ) + } + None => { + base.custom_deserialize = Some(custom_deserialize.to_string()); + } + } + } } } base.verify(); @@ -113,6 +166,28 @@ fn tag_custom_json(input: &str) -> IResult<&str, ParseResult> { Ok((input, ParseResult::CustomJson)) } +fn tag_custom_serialize(input: &str) -> IResult<&str, ParseResult> { + let (input, _) = tag("@custom_serialize")(input)?; + let (input, _) = take_while(char::is_whitespace)(input)?; + let (input, custom_serialize) = take_while1(|ch| !char::is_whitespace(ch))(input)?; + + Ok(( + input, + ParseResult::CustomSerialize(custom_serialize.to_string()), + )) +} + +fn tag_custom_deserialize(input: &str) -> IResult<&str, ParseResult> { + let (input, _) = tag("@custom_deserialize")(input)?; + let (input, _) = take_while(char::is_whitespace)(input)?; + let (input, custom_deserialize) = take_while1(|ch| !char::is_whitespace(ch))(input)?; + + Ok(( + input, + ParseResult::CustomDeserialize(custom_deserialize.to_string()), + )) +} + fn whitespace_then_tag(input: &str) -> IResult<&str, ParseResult> { let (input, _) = take_while(char::is_whitespace)(input)?; let (input, result) = alt(( @@ -121,6 +196,8 @@ fn whitespace_then_tag(input: &str) -> IResult<&str, ParseResult> { tag_no_alias, tag_used_as_key, tag_custom_json, + tag_custom_serialize, + tag_custom_deserialize, ))(input)?; Ok((input, result)) @@ -163,6 +240,8 @@ fn parse_comment_name() { no_alias: false, used_as_key: false, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -180,6 +259,8 @@ fn parse_comment_newtype() { no_alias: false, used_as_key: false, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -197,6 +278,8 @@ fn parse_comment_newtype_and_name() { no_alias: false, used_as_key: false, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); 
@@ -214,6 +297,8 @@ fn parse_comment_newtype_and_name_and_used_as_key() { no_alias: false, used_as_key: true, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -231,6 +316,8 @@ fn parse_comment_used_as_key() { no_alias: false, used_as_key: true, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -248,6 +335,8 @@ fn parse_comment_newtype_and_name_inverse() { no_alias: false, used_as_key: false, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -265,6 +354,8 @@ fn parse_comment_name_noalias() { no_alias: true, used_as_key: false, custom_json: false, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -282,6 +373,8 @@ fn parse_comment_newtype_and_custom_json() { no_alias: false, used_as_key: false, custom_json: true, + custom_serialize: None, + custom_deserialize: None, } )) ); @@ -292,3 +385,42 @@ fn parse_comment_newtype_and_custom_json() { fn parse_comment_noalias_newtype() { let _ = rule_metadata("@no_alias @newtype"); } + +#[test] +fn parse_comment_custom_serialize_deserialize() { + assert_eq!( + rule_metadata("@custom_serialize foo @custom_deserialize bar"), + Ok(( + "", + RuleMetadata { + name: None, + is_newtype: false, + no_alias: false, + used_as_key: false, + custom_json: false, + custom_serialize: Some("foo".to_string()), + custom_deserialize: Some("bar".to_string()), + } + )) + ); +} + +// can't have all since @no_alias and @newtype are mutually exclusive +#[test] +fn parse_comment_all_except_no_alias() { + assert_eq!( + rule_metadata("@newtype @name baz @custom_serialize foo @custom_deserialize bar @used_as_key @custom_json"), + Ok(( + "", + RuleMetadata { + name: Some("baz".to_string()), + is_newtype: true, + no_alias: false, + used_as_key: true, + custom_json: true, + custom_serialize: Some("foo".to_string()), + custom_deserialize: Some("bar".to_string()), + } + )) + ); +} diff --git a/src/generation.rs b/src/generation.rs index 
8ea0e0d..d28d31b 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -9,8 +9,8 @@ use std::process::{Command, Stdio}; use crate::intermediate::{ AliasIdent, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, EnumVariantData, FixedValue, IntermediateTypes, ModuleScope, Primitive, Representation, RustField, RustIdent, - RustRecord, RustStructCBORLen, RustStructType, RustType, RustTypeSerializeConfig, - ToWasmBoundaryOperations, VariantIdent, ROOT_SCOPE, + RustRecord, RustStructCBORLen, RustStructConfig, RustStructType, RustType, + RustTypeSerializeConfig, ToWasmBoundaryOperations, VariantIdent, ROOT_SCOPE, }; use crate::utils::{cbor_type_code_str, convert_to_snake_case}; @@ -28,6 +28,8 @@ struct SerializeConfig<'a> { encoding_var_in_option_struct: Option, /// an overload instead of using "serializer". (name, is_local) - if is_local then &mut will be appended when needed. serializer_name_overload: Option<(&'a str, bool)>, + /// Override regular serialization logic with a call to this function + custom_serialize: Option, } impl<'a> SerializeConfig<'a> { @@ -40,6 +42,7 @@ impl<'a> SerializeConfig<'a> { encoding_var_is_ref: false, encoding_var_in_option_struct: None, serializer_name_overload: None, + custom_serialize: None, } } @@ -84,6 +87,11 @@ impl<'a> SerializeConfig<'a> { self } + fn custom_serialize(mut self, func: String) -> Self { + self.custom_serialize = Some(func); + self + } + fn encoding_var(&self, child: Option<&str>, is_copy: bool) -> String { let child_suffix = match child { Some(c) => format!("_{c}"), @@ -166,6 +174,8 @@ struct DeserializeConfig<'a> { deserializer_name_overload: Option<&'a str>, /// Overload for read_len. This would be a local e.g. 
for arrays read_len_overload: Option, + /// Override regular deserialization logic with a call to this function + custom_deserialize: Option, } impl<'a> DeserializeConfig<'a> { @@ -177,6 +187,7 @@ impl<'a> DeserializeConfig<'a> { final_exprs: Vec::new(), deserializer_name_overload: None, read_len_overload: None, + custom_deserialize: None, } } @@ -204,6 +215,11 @@ impl<'a> DeserializeConfig<'a> { self } + fn custom_deserialize(mut self, func: String) -> Self { + self.custom_deserialize = Some(func); + self + } + fn pass_read_len(&self) -> String { if let Some(overload) = &self.read_len_overload { // the ONLY way to have a name overload is if we have a local variable (e.g. arrays) @@ -604,7 +620,15 @@ impl GenerationScope { } match rust_struct.variant() { RustStructType::Record(record) => { - codegen_struct(self, types, rust_ident, rust_struct.tag(), record, cli); + codegen_struct( + self, + types, + rust_ident, + rust_struct.tag(), + record, + rust_struct.config(), + cli, + ); } RustStructType::Table { domain, range } => { if cli.wasm { @@ -646,6 +670,7 @@ impl GenerationScope { rust_ident, variants, rust_struct.tag(), + rust_struct.config(), cli, ); } @@ -656,20 +681,17 @@ impl GenerationScope { variants, *rep, rust_struct.tag(), + rust_struct.config(), cli, ), - RustStructType::Wrapper { - wrapped, - min_max, - custom_json, - } => match rust_struct.tag() { + RustStructType::Wrapper { wrapped, min_max } => match rust_struct.tag() { Some(tag) => generate_wrapper_struct( self, types, rust_ident, &wrapped.clone().tag(tag), *min_max, - *custom_json, + rust_struct.config(), cli, ), None => generate_wrapper_struct( @@ -678,7 +700,7 @@ impl GenerationScope { rust_ident, wrapped, *min_max, - *custom_json, + rust_struct.config(), cli, ), }, @@ -700,6 +722,7 @@ impl GenerationScope { rust_ident, variants, rust_struct.tag(), + rust_struct.config(), cli, ); } @@ -1177,12 +1200,8 @@ impl GenerationScope { .rust_structs() .iter() .any(|(_, rust_struct)| match 
&rust_struct.variant { - RustStructType::Wrapper { - wrapped, - custom_json, - .. - } => { - !custom_json + RustStructType::Wrapper { wrapped, .. } => { + !rust_struct.config().custom_json && matches!( wrapped.resolve_alias_shallow(), ConceptualRustType::Primitive(Primitive::Bytes) @@ -1409,453 +1428,546 @@ impl GenerationScope { let encoding_var_is_copy = serializing_rust_type.encoding_var_is_copy(types); let encoding_var = config.encoding_var(None, encoding_var_is_copy); let encoding_var_deref = format!("{encoding_deref}{encoding_var}"); - match serializing_rust_type { - SerializingRustType::EncodingOperation(CBOREncodingOperation::Tagged(tag), child) => { - let expr = format!("{tag}u64"); - write_using_sz( - body, - "write_tag", - serializer_use, - &expr, - &expr, - "?;", - &format!( - "{}{}", - encoding_deref, - config.encoding_var(Some("tag"), encoding_var_is_copy) - ), - cli, - ); - self.generate_serialize(types, *child, body, config, cli); - } - SerializingRustType::EncodingOperation(CBOREncodingOperation::CBORBytes, child) => { - let inner_se = format!("{}_inner_se", config.var_name); - body.line(&format!("let mut {inner_se} = Serializer::new_vec();")); - let inner_config = config - .clone() - .is_end(false) - .serializer_name_overload((&inner_se, true)); - self.generate_serialize(types, *child, body, inner_config, cli); - body.line(&format!( - "let {}_bytes = {}.finalize();", - config.var_name, inner_se - )); - write_string_sz( - body, - "write_bytes", - serializer_use, - &format!("{}_bytes", config.var_name), - true, - line_ender, - &config.encoding_var(Some("bytes"), encoding_var_is_copy), - cli, - ); - } - SerializingRustType::Root(ConceptualRustType::Fixed(value), _cfg) => match value { - FixedValue::Null => { - body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Null){line_ender}" - )); - } - FixedValue::Bool(b) => { - body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Bool({b})){line_ender}" - )); - } - 
FixedValue::Uint(u) => { - let expr = format!("{u}u64"); + // field-level @custom_serialize overrides everything + if let Some(custom_serialize) = &config.custom_serialize { + let pass_encoding_args = if cli.preserve_encodings { + Cow::Owned( + encoding_fields_impl(types, &config.var_name, serializing_rust_type, cli) + .into_iter() + .map(|enc| { + format!( + ", {}", + match &config.encoding_var_in_option_struct { + Some(namespace) => format!( + "{}{}.as_ref().map(|encs| encs.{}{}).unwrap_or_default()", + if enc.is_copy { "" } else { "&" }, + namespace, + enc.field_name, + if enc.is_copy { "" } else { ".clone()" }, + ), + None => enc.field_name.clone(), + } + ) + }) + .collect::>() + .join(""), + ) + } else { + Cow::Borrowed("") + }; + body.line(&format!( + "{}({}, {}{}{}){}", + custom_serialize, + serializer_use, + expr_ref, + pass_encoding_args, + canonical_param(cli), + line_ender + )); + } else { + match serializing_rust_type { + SerializingRustType::EncodingOperation( + CBOREncodingOperation::Tagged(tag), + child, + ) => { + let expr = format!("{tag}u64"); write_using_sz( body, - "write_unsigned_integer", + "write_tag", serializer_use, &expr, &expr, - line_ender, - &encoding_var_deref, + "?;", + &format!( + "{}{}", + encoding_deref, + config.encoding_var(Some("tag"), encoding_var_is_copy) + ), cli, ); + self.generate_serialize(types, *child, body, config, cli); } - FixedValue::Nint(i) => { - assert!(*i < 0); - if !cli.preserve_encodings - && isize::BITS >= i64::BITS - && *i <= i64::MIN as isize - { - // cbor_event's write_negative_integer doesn't support serializing i64::MIN (https://github.com/primetype/cbor_event/issues/9) - // we need to use the write_negative_integer_sz endpoint which does support it. - // the bits check is since the constant parsed by cddl might not even be able to - // be that small e.g. 
on 32-bit platforms in which case we're already working with garbage - let sz_str = if *i >= -24 { - "cbor_event::Sz::Inline" - } else if *i >= -0x1_00 { - "cbor_event::Sz::One" - } else if *i >= -0x1_00_00 { - "cbor_event::Sz::Two" - } else if *i >= -0x1_00_00_00_00 { - "cbor_event::Sz::Four" - } else { - "cbor_event::Sz::Eight" - }; - body.line(&format!( - "{serializer_use}.write_negative_integer_sz({i}i128, {sz_str}){line_ender}" - )); - } else { - write_using_sz( - body, - "write_negative_integer", - serializer_use, - &i.to_string(), - &format!("({i}i128 + 1).abs() as u64"), - line_ender, - &encoding_var_deref, - cli, - ); - } - } - FixedValue::Float(f) => { + SerializingRustType::EncodingOperation(CBOREncodingOperation::CBORBytes, child) => { + let inner_se = format!("{}_inner_se", config.var_name); + body.line(&format!("let mut {inner_se} = Serializer::new_vec();")); + let inner_config = config + .clone() + .is_end(false) + .serializer_name_overload((&inner_se, true)); + self.generate_serialize(types, *child, body, inner_config, cli); body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Float({f})){line_ender}" + "let {}_bytes = {}.finalize();", + config.var_name, inner_se )); - } - FixedValue::Text(s) => { write_string_sz( body, - "write_text", + "write_bytes", serializer_use, - &format!("\"{s}\""), - false, + &format!("{}_bytes", config.var_name), + true, line_ender, - &encoding_var, + &config.encoding_var(Some("bytes"), encoding_var_is_copy), cli, ); } - }, - SerializingRustType::Root(ConceptualRustType::Primitive(primitive), _cfg) => { - match primitive { - Primitive::Bool => { - body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Bool({expr_deref})){line_ender}" - )); - } - Primitive::F32 => { + SerializingRustType::Root(ConceptualRustType::Fixed(value), _cfg) => match value { + FixedValue::Null => { body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Float({expr_deref} as 
f64)){line_ender}" + "{serializer_use}.write_special(cbor_event::Special::Null){line_ender}" )); } - Primitive::F64 => { + FixedValue::Bool(b) => { body.line(&format!( - "{serializer_use}.write_special(cbor_event::Special::Float({expr_deref})){line_ender}" + "{serializer_use}.write_special(cbor_event::Special::Bool({b})){line_ender}" )); } - Primitive::Bytes => { - write_string_sz( - body, - "write_bytes", - serializer_use, - &config.expr, - true, - line_ender, - &encoding_var, - cli, - ); - } - Primitive::Str => { - write_string_sz( - body, - "write_text", - serializer_use, - &config.expr, - true, - line_ender, - &encoding_var, - cli, - ); - } - Primitive::I8 | Primitive::I16 | Primitive::I32 | Primitive::I64 => { - let mut pos = Block::new(format!("if {expr_deref} >= 0")); - let expr_pos = format!("{expr_deref} as u64"); + FixedValue::Uint(u) => { + let expr = format!("{u}u64"); write_using_sz( - &mut pos, + body, "write_unsigned_integer", serializer_use, - &expr_pos, - &expr_pos, + &expr, + &expr, line_ender, &encoding_var_deref, cli, ); - body.push_block(pos); - let mut neg = Block::new("else"); - // only the _sz variants support i128, the other endpoint is i64 - let expr = if cli.preserve_encodings { - format!("{expr_deref} as i128") - } else { - format!("{expr_deref} as i64") - }; - if !cli.preserve_encodings && *primitive == Primitive::I64 { - // https://github.com/primetype/cbor_event/issues/9 - // cbor_event doesn't support i64::MIN on write_negative_integer() so we use write_negative_integer_sz() for i64s - // even when not preserving encodings - neg.line(&format!("{serializer_use}.write_negative_integer_sz({expr_deref} as i128, cbor_event::Sz::canonical(({expr_deref} + 1).abs() as u64)){line_ender}")); + } + FixedValue::Nint(i) => { + assert!(*i < 0); + if !cli.preserve_encodings + && isize::BITS >= i64::BITS + && *i <= i64::MIN as isize + { + // cbor_event's write_negative_integer doesn't support serializing i64::MIN 
(https://github.com/primetype/cbor_event/issues/9) + // we need to use the write_negative_integer_sz endpoint which does support it. + // the bits check is since the constant parsed by cddl might not even be able to + // be that small e.g. on 32-bit platforms in which case we're already working with garbage + let sz_str = if *i >= -24 { + "cbor_event::Sz::Inline" + } else if *i >= -0x1_00 { + "cbor_event::Sz::One" + } else if *i >= -0x1_00_00 { + "cbor_event::Sz::Two" + } else if *i >= -0x1_00_00_00_00 { + "cbor_event::Sz::Four" + } else { + "cbor_event::Sz::Eight" + }; + body.line(&format!( + "{serializer_use}.write_negative_integer_sz({i}i128, {sz_str}){line_ender}" + )); } else { write_using_sz( - &mut neg, + body, "write_negative_integer", serializer_use, - &expr, - &format!("({expr_deref} + 1).abs() as u64"), + &i.to_string(), + &format!("({i}i128 + 1).abs() as u64"), line_ender, &encoding_var_deref, cli, ); } - body.push_block(neg); } - Primitive::U8 | Primitive::U16 | Primitive::U32 => { - let expr = format!("{expr_deref} as u64"); - write_using_sz( - body, - "write_unsigned_integer", - serializer_use, - &expr, - &expr, - line_ender, - &encoding_var_deref, - cli, - ); + FixedValue::Float(f) => { + body.line(&format!( + "{serializer_use}.write_special(cbor_event::Special::Float({f})){line_ender}" + )); } - Primitive::U64 => { - write_using_sz( + FixedValue::Text(s) => { + write_string_sz( body, - "write_unsigned_integer", + "write_text", serializer_use, - &expr_deref, - &expr_deref, + &format!("\"{s}\""), + false, line_ender, - &encoding_var_deref, + &encoding_var, cli, ); } - Primitive::N64 => { - if cli.preserve_encodings { + }, + SerializingRustType::Root(ConceptualRustType::Primitive(primitive), _cfg) => { + match primitive { + Primitive::Bool => { + body.line(&format!( + "{serializer_use}.write_special(cbor_event::Special::Bool({expr_deref})){line_ender}" + )); + } + Primitive::F32 => { + body.line(&format!( + 
"{serializer_use}.write_special(cbor_event::Special::Float({expr_deref} as f64)){line_ender}" + )); + } + Primitive::F64 => { + body.line(&format!( + "{serializer_use}.write_special(cbor_event::Special::Float({expr_deref})){line_ender}" + )); + } + Primitive::Bytes => { + write_string_sz( + body, + "write_bytes", + serializer_use, + &config.expr, + true, + line_ender, + &encoding_var, + cli, + ); + } + Primitive::Str => { + write_string_sz( + body, + "write_text", + serializer_use, + &config.expr, + true, + line_ender, + &encoding_var, + cli, + ); + } + Primitive::I8 | Primitive::I16 | Primitive::I32 | Primitive::I64 => { + let mut pos = Block::new(format!("if {expr_deref} >= 0")); + let expr_pos = format!("{expr_deref} as u64"); + write_using_sz( + &mut pos, + "write_unsigned_integer", + serializer_use, + &expr_pos, + &expr_pos, + line_ender, + &encoding_var_deref, + cli, + ); + body.push_block(pos); + let mut neg = Block::new("else"); + // only the _sz variants support i128, the other endpoint is i64 + let expr = if cli.preserve_encodings { + format!("{expr_deref} as i128") + } else { + format!("{expr_deref} as i64") + }; + if !cli.preserve_encodings && *primitive == Primitive::I64 { + // https://github.com/primetype/cbor_event/issues/9 + // cbor_event doesn't support i64::MIN on write_negative_integer() so we use write_negative_integer_sz() for i64s + // even when not preserving encodings + neg.line(&format!("{serializer_use}.write_negative_integer_sz({expr_deref} as i128, cbor_event::Sz::canonical(({expr_deref} + 1).abs() as u64)){line_ender}")); + } else { + write_using_sz( + &mut neg, + "write_negative_integer", + serializer_use, + &expr, + &format!("({expr_deref} + 1).abs() as u64"), + line_ender, + &encoding_var_deref, + cli, + ); + } + body.push_block(neg); + } + Primitive::U8 | Primitive::U16 | Primitive::U32 => { + let expr = format!("{expr_deref} as u64"); write_using_sz( body, - "write_negative_integer", + "write_unsigned_integer", + serializer_use, + 
&expr, + &expr, + line_ender, + &encoding_var_deref, + cli, + ); + } + Primitive::U64 => { + write_using_sz( + body, + "write_unsigned_integer", serializer_use, - &format!("-({expr_deref} as i128 + 1)"), + &expr_deref, &expr_deref, line_ender, &encoding_var_deref, cli, ); - } else { - // https://github.com/primetype/cbor_event/issues/9 - // cbor_event doesn't support i64::MIN on write_negative_integer() so we use write_negative_integer_sz() - // even when not preserving encodings - body.line(&format!("{serializer_use}.write_negative_integer_sz(-({expr_deref} as i128 + 1), cbor_event::Sz::canonical({expr_deref})){line_ender}")); + } + Primitive::N64 => { + if cli.preserve_encodings { + write_using_sz( + body, + "write_negative_integer", + serializer_use, + &format!("-({expr_deref} as i128 + 1)"), + &expr_deref, + line_ender, + &encoding_var_deref, + cli, + ); + } else { + // https://github.com/primetype/cbor_event/issues/9 + // cbor_event doesn't support i64::MIN on write_negative_integer() so we use write_negative_integer_sz() + // even when not preserving encodings + body.line(&format!("{serializer_use}.write_negative_integer_sz(-({expr_deref} as i128 + 1), cbor_event::Sz::canonical({expr_deref})){line_ender}")); + } } } } - } - SerializingRustType::Root(ConceptualRustType::Rust(t), type_cfg) => { - match &types.rust_struct(t).unwrap().variant() { - RustStructType::CStyleEnum { variants } => { - let mut enum_body = Block::new(format!("match {expr_ref}")); - for variant in variants { - let mut variant_match = - Block::new(format!("{}::{} =>", t, variant.name)); - self.generate_serialize( - types, - (variant.rust_type()).into(), - &mut variant_match, - config.clone().is_end(true), + SerializingRustType::Root(ConceptualRustType::Rust(t), type_cfg) => { + match &types.rust_struct(t).unwrap().variant() { + RustStructType::CStyleEnum { variants } => { + let mut enum_body = Block::new(format!("match {expr_ref}")); + for variant in variants { + let mut variant_match = + 
Block::new(format!("{}::{} =>", t, variant.name)); + self.generate_serialize( + types, + (variant.rust_type()).into(), + &mut variant_match, + config.clone().is_end(true), + cli, + ); + enum_body.push_block(variant_match); + } + if !config.is_end { + enum_body.after("?;"); + } + body.push_block(enum_body); + } + RustStructType::RawBytesType => { + write_string_sz( + body, + "write_bytes", + serializer_use, + &format!("{}.to_raw_bytes()", config.expr), + false, + line_ender, + &config.encoding_var(None, false), cli, ); - enum_body.push_block(variant_match); - } - if !config.is_end { - enum_body.after("?;"); } - body.push_block(enum_body); - } - RustStructType::RawBytesType => { - write_string_sz( - body, - "write_bytes", - serializer_use, - &format!("{}.to_raw_bytes()", config.expr), - false, - line_ender, - &config.encoding_var(None, false), - cli, - ); - } - _ => { - if types.is_plain_group(t) && !type_cfg.basic_override { - body.line(&format!( - "{}.serialize_as_embedded_group({}{}){}", - config.expr, - serializer_pass, - canonical_param(cli), - line_ender - )); - } else { - body.line(&format!( - "{}.serialize({}{}){}", - config.expr, - serializer_pass, - canonical_param(cli), - line_ender - )); + _ => { + if types.is_plain_group(t) && !type_cfg.basic_override { + body.line(&format!( + "{}.serialize_as_embedded_group({}{}){}", + config.expr, + serializer_pass, + canonical_param(cli), + line_ender + )); + } else { + body.line(&format!( + "{}.serialize({}{}){}", + config.expr, + serializer_pass, + canonical_param(cli), + line_ender + )); + } } } } - } - SerializingRustType::Root(ConceptualRustType::Array(ty), _cfg) => { - let len_expr = match &ty.conceptual_type { - ConceptualRustType::Rust(elem_ident) if types.is_plain_group(elem_ident) => { - // you should not be able to indiscriminately encode a plain group like this as it - // could be multiple elements. This would require special handling if it's even permitted in CDDL. 
- assert!(ty.encodings.is_empty()); - if let Some(fixed_elem_size) = ty.expanded_field_count(types) { - format!("{} * {}.len() as u64", fixed_elem_size, config.expr) - } else { - format!( - "{}.iter().map(|e| {}).sum()", - config.expr, - ty.definite_info("e", true, types, cli) - ) + SerializingRustType::Root(ConceptualRustType::Array(ty), _cfg) => { + let len_expr = match &ty.conceptual_type { + ConceptualRustType::Rust(elem_ident) + if types.is_plain_group(elem_ident) => + { + // you should not be able to indiscriminately encode a plain group like this as it + // could be multiple elements. This would require special handling if it's even permitted in CDDL. + assert!(ty.encodings.is_empty()); + if let Some(fixed_elem_size) = ty.expanded_field_count(types) { + format!("{} * {}.len() as u64", fixed_elem_size, config.expr) + } else { + format!( + "{}.iter().map(|e| {}).sum()", + config.expr, + ty.definite_info("e", true, types, cli) + ) + } } - } - _ => format!("{}.len() as u64", config.expr), - }; - start_len( - body, - Representation::Array, - serializer_use, - &encoding_var, - &len_expr, - cli, - ); - let elem_var_name = format!("{}_elem", config.var_name); - let elem_encs = if cli.preserve_encodings { - encoding_fields( - types, - &elem_var_name, - &ty.clone().resolve_aliases(), - false, - cli, - ) - } else { - vec![] - }; - let mut loop_block = if !elem_encs.is_empty() { - let mut block = Block::new(format!( - "for (i, element) in {}.iter().enumerate()", - config.expr - )); - block.line(config.container_encoding_lookup("elem", &elem_encs, "i")); - block - } else { - Block::new(format!("for element in {}.iter()", config.expr)) - }; - let elem_config = config - .clone() - .expr("element") - .expr_is_ref(true) - .var_name(elem_var_name) - .is_end(false) - .encoding_var_no_option_struct() - .encoding_var_is_ref(false); - self.generate_serialize(types, (&**ty).into(), &mut loop_block, elem_config, cli); - body.push_block(loop_block); - end_len(body, serializer_use, 
&encoding_var, config.is_end, cli); - } - SerializingRustType::Root(ConceptualRustType::Map(key, value), _cfg) => { - start_len( - body, - Representation::Map, - serializer_use, - &encoding_var, - &format!("{}.len() as u64", config.expr), - cli, - ); - let ser_loop = if cli.preserve_encodings { - let key_enc_fields = encoding_fields( - types, - &format!("{}_key", config.var_name), - &key.clone().resolve_aliases(), - false, - cli, - ); - let value_enc_fields = encoding_fields( - types, - &format!("{}_value", config.var_name), - &value.clone().resolve_aliases(), - false, + _ => format!("{}.len() as u64", config.expr), + }; + start_len( + body, + Representation::Array, + serializer_use, + &encoding_var, + &len_expr, cli, ); - let mut ser_loop = if cli.canonical_form { - let mut key_order = Block::new(format!( - "let mut key_order = {}.iter().map(|(k, v)|", - config.expr - )); - key_order.line("let mut buf = cbor_event::se::Serializer::new_vec();"); - if !key_enc_fields.is_empty() { - key_order.line(config.container_encoding_lookup( - "key", - &key_enc_fields, - "k", - )); - } - let key_config = - SerializeConfig::new("k", format!("{}_key", config.var_name)) + let elem_var_name = format!("{}_elem", config.var_name); + let elem_encs = if cli.preserve_encodings { + encoding_fields( + types, + &elem_var_name, + &ty.clone().resolve_aliases(), + false, + cli, + ) + } else { + vec![] + }; + let mut loop_block = if !elem_encs.is_empty() { + let mut block = Block::new(format!( + "for (i, element) in {}.iter().enumerate()", + config.expr + )); + block.line(config.container_encoding_lookup("elem", &elem_encs, "i")); + block + } else { + Block::new(format!("for element in {}.iter()", config.expr)) + }; + let elem_config = config + .clone() + .expr("element") + .expr_is_ref(true) + .var_name(elem_var_name) + .is_end(false) + .encoding_var_no_option_struct() + .encoding_var_is_ref(false); + self.generate_serialize( + types, + (&**ty).into(), + &mut loop_block, + elem_config, + cli, 
+ ); + body.push_block(loop_block); + end_len(body, serializer_use, &encoding_var, config.is_end, cli); + } + SerializingRustType::Root(ConceptualRustType::Map(key, value), _cfg) => { + start_len( + body, + Representation::Map, + serializer_use, + &encoding_var, + &format!("{}.len() as u64", config.expr), + cli, + ); + let ser_loop = if cli.preserve_encodings { + let key_enc_fields = encoding_fields( + types, + &format!("{}_key", config.var_name), + &key.clone().resolve_aliases(), + false, + cli, + ); + let value_enc_fields = encoding_fields( + types, + &format!("{}_value", config.var_name), + &value.clone().resolve_aliases(), + false, + cli, + ); + let mut ser_loop = if cli.canonical_form { + let mut key_order = Block::new(format!( + "let mut key_order = {}.iter().map(|(k, v)|", + config.expr + )); + key_order.line("let mut buf = cbor_event::se::Serializer::new_vec();"); + if !key_enc_fields.is_empty() { + key_order.line(config.container_encoding_lookup( + "key", + &key_enc_fields, + "k", + )); + } + let key_config = + SerializeConfig::new("k", format!("{}_key", config.var_name)) + .expr_is_ref(true) + .is_end(false) + .serializer_name_overload(("buf", true)) + .encoding_var_is_ref(false); + self.generate_serialize( + types, + (&**key).into(), + &mut key_order, + key_config, + cli, + ); + key_order.line("Ok((buf.finalize(), k, v))").after( + ").collect::, &_, &_)>, cbor_event::Error>>()?;", + ); + body.push_block(key_order); + let mut key_order_if = Block::new("if force_canonical"); + let mut key_order_sort = Block::new( + "key_order.sort_by(|(lhs_bytes, _, _), (rhs_bytes, _, _)|", + ); + let mut key_order_sort_match = + Block::new("match lhs_bytes.len().cmp(&rhs_bytes.len())"); + key_order_sort_match + .line("std::cmp::Ordering::Equal => lhs_bytes.cmp(rhs_bytes),") + .line("diff_ord => diff_ord,"); + key_order_sort.push_block(key_order_sort_match).after(");"); + key_order_if.push_block(key_order_sort); + body.push_block(key_order_if); + let key_loop_var = if 
value_enc_fields.is_empty() { + "_key" + } else { + "key" + }; + let mut ser_loop = Block::new(format!( + "for (key_bytes, {key_loop_var}, value) in key_order" + )); + ser_loop + .line(format!("{serializer_use}.write_raw_bytes(&key_bytes)?;")); + ser_loop + } else { + let mut ser_loop = + Block::new(format!("for (key, value) in {}.iter()", config.expr)); + if !key_enc_fields.is_empty() { + ser_loop.line(config.container_encoding_lookup( + "key", + &key_enc_fields, + "key", + )); + } + let key_config = config + .clone() + .expr("key") .expr_is_ref(true) + .var_name(format!("{}_key", config.var_name)) .is_end(false) - .serializer_name_overload(("buf", true)) + .encoding_var_no_option_struct() .encoding_var_is_ref(false); + self.generate_serialize( + types, + (&**key).into(), + &mut ser_loop, + key_config, + cli, + ); + ser_loop + }; + if !value_enc_fields.is_empty() { + ser_loop.line(config.container_encoding_lookup( + "value", + &value_enc_fields, + "key", + )); + } + let value_config = config + .clone() + .expr("value") + .expr_is_ref(true) + .var_name(format!("{}_value", config.var_name)) + .is_end(false) + .encoding_var_no_option_struct() + .encoding_var_is_ref(false); self.generate_serialize( types, - (&**key).into(), - &mut key_order, - key_config, + (&**value).into(), + &mut ser_loop, + value_config, cli, ); - key_order.line("Ok((buf.finalize(), k, v))").after( - ").collect::, &_, &_)>, cbor_event::Error>>()?;", - ); - body.push_block(key_order); - let mut key_order_if = Block::new("if force_canonical"); - let mut key_order_sort = - Block::new("key_order.sort_by(|(lhs_bytes, _, _), (rhs_bytes, _, _)|"); - let mut key_order_sort_match = - Block::new("match lhs_bytes.len().cmp(&rhs_bytes.len())"); - key_order_sort_match - .line("std::cmp::Ordering::Equal => lhs_bytes.cmp(rhs_bytes),") - .line("diff_ord => diff_ord,"); - key_order_sort.push_block(key_order_sort_match).after(");"); - key_order_if.push_block(key_order_sort); - body.push_block(key_order_if); - let 
key_loop_var = if value_enc_fields.is_empty() { - "_key" - } else { - "key" - }; - let mut ser_loop = Block::new(format!( - "for (key_bytes, {key_loop_var}, value) in key_order" - )); - ser_loop.line(format!("{serializer_use}.write_raw_bytes(&key_bytes)?;")); ser_loop } else { let mut ser_loop = Block::new(format!("for (key, value) in {}.iter()", config.expr)); - if !key_enc_fields.is_empty() { - ser_loop.line(config.container_encoding_lookup( - "key", - &key_enc_fields, - "key", - )); - } let key_config = config .clone() .expr("key") @@ -1864,6 +1976,10 @@ impl GenerationScope { .is_end(false) .encoding_var_no_option_struct() .encoding_var_is_ref(false); + let value_config = key_config + .clone() + .expr("value") + .var_name(format!("{}_value", config.var_name)); self.generate_serialize( types, (&**key).into(), @@ -1871,80 +1987,58 @@ impl GenerationScope { key_config, cli, ); + self.generate_serialize( + types, + (&**value).into(), + &mut ser_loop, + value_config, + cli, + ); ser_loop }; - if !value_enc_fields.is_empty() { - ser_loop.line(config.container_encoding_lookup( - "value", - &value_enc_fields, - "key", - )); - } - let value_config = config - .clone() - .expr("value") - .expr_is_ref(true) - .var_name(format!("{}_value", config.var_name)) - .is_end(false) - .encoding_var_no_option_struct() - .encoding_var_is_ref(false); - self.generate_serialize( - types, - (&**value).into(), - &mut ser_loop, - value_config, - cli, - ); - ser_loop - } else { - let mut ser_loop = - Block::new(format!("for (key, value) in {}.iter()", config.expr)); - let key_config = config - .clone() - .expr("key") - .expr_is_ref(true) - .var_name(format!("{}_key", config.var_name)) - .is_end(false) - .encoding_var_no_option_struct() - .encoding_var_is_ref(false); - let value_config = key_config - .clone() - .expr("value") - .var_name(format!("{}_value", config.var_name)); - self.generate_serialize(types, (&**key).into(), &mut ser_loop, key_config, cli); + body.push_block(ser_loop); + 
end_len(body, serializer_use, &encoding_var, config.is_end, cli); + } + SerializingRustType::Root(ConceptualRustType::Optional(ty), _cfg) => { + let mut opt_block = Block::new(format!("match {expr_ref}")); + // TODO: do this in one line without a block if possible somehow. + // see other comment in generate_enum() + let mut some_block = Block::new("Some(x) =>"); + let opt_config = config.clone().expr("x").expr_is_ref(true).is_end(true); self.generate_serialize( types, - (&**value).into(), - &mut ser_loop, - value_config, + (&**ty).into(), + &mut some_block, + opt_config, cli, ); - ser_loop - }; - body.push_block(ser_loop); - end_len(body, serializer_use, &encoding_var, config.is_end, cli); - } - SerializingRustType::Root(ConceptualRustType::Optional(ty), _cfg) => { - let mut opt_block = Block::new(format!("match {expr_ref}")); - // TODO: do this in one line without a block if possible somehow. - // see other comment in generate_enum() - let mut some_block = Block::new("Some(x) =>"); - let opt_config = config.clone().expr("x").expr_is_ref(true).is_end(true); - self.generate_serialize(types, (&**ty).into(), &mut some_block, opt_config, cli); - some_block.after(","); - opt_block.push_block(some_block); - opt_block.line(&format!( - "None => {serializer_use}.write_special(cbor_event::Special::Null)," - )); - if !config.is_end { - opt_block.after("?;"); + some_block.after(","); + opt_block.push_block(some_block); + opt_block.line(&format!( + "None => {serializer_use}.write_special(cbor_event::Special::Null)," + )); + if !config.is_end { + opt_block.after("?;"); + } + body.push_block(opt_block); } - body.push_block(opt_block); - } - SerializingRustType::Root(ConceptualRustType::Alias(_ident, ty), _cfg) => { - self.generate_serialize(types, (&**ty).into(), body, config, cli) - } - }; + SerializingRustType::Root(ConceptualRustType::Alias(ident, ty), _cfg) => { + let config_for_alias = if let Some(custom_serialize) = types + .type_aliases() + .get(ident) + .unwrap() + 
.rule_metadata + .as_ref() + .and_then(|rmd| rmd.custom_serialize.clone()) + { + config.custom_serialize(custom_serialize) + } else { + config + }; + self.generate_serialize(types, (&**ty).into(), body, config_for_alias, cli) + } + }; + } } /// Generates a DeserializationCode to serialize {serializing_rust_type} using the context in {before_after} @@ -2003,326 +2097,368 @@ impl GenerationScope { } }; let deserializer_name = config.deserializer_name(); - match serializing_rust_type { - SerializingRustType::Root(ConceptualRustType::Fixed(f), _cfg) => { - if !cli.preserve_encodings { - // we don't evaluate to any values here, just verify - // before/after are ignored and we need to handle fixed value deserialization in a different way - // than normal ones. - assert_eq!(before_after.after, ""); - assert_eq!(before_after.before, ""); - } - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.throws = true; - deser_code.read_len_used = true; - } - match f { - FixedValue::Null => { - let mut special_block = Block::new(format!( - "if {deserializer_name}.special()? 
!= cbor_event::Special::Null" - )); - special_block.line("return Err(DeserializeFailure::ExpectedNull.into());"); - deser_code.content.push_block(special_block); - if cli.preserve_encodings { - deser_code.content.line(&format!( - "{}{}{}", - before_after.before_str(false), - final_expr(config.final_exprs, None), - before_after.after_str(false) - )); - } + // field-level @custom_deserialize overrides everything + if let Some(custom_deserialize) = &config.custom_deserialize { + let deser_err_map = if !config.final_exprs.is_empty() { + let enc_fields = + encoding_fields_impl(types, config.var_name, serializing_rust_type, cli); + let (closure_args, tuple_fields) = if enc_fields.is_empty() { + (config.var_name.to_owned(), "".to_owned()) + } else { + let enc_fields_names = enc_fields + .iter() + .map(|enc| enc.field_name.clone()) + .collect::>() + .join(", "); + ( + format!("({}, {})", config.var_name, enc_fields_names), + enc_fields_names, + ) + }; + Cow::Owned(format!( + ".map(|{}| ({}, {}, {}))", + closure_args, + config.var_name, + config.final_exprs.join(", "), + tuple_fields + )) + } else { + Cow::Borrowed("") + }; + deser_code.content.line(&format!( + "{}{}({}){}{}", + before_after.before_str(true), + custom_deserialize, + deserializer_name, + deser_err_map, + before_after.after_str(true), + )); + } else { + match serializing_rust_type { + SerializingRustType::Root(ConceptualRustType::Fixed(f), _cfg) => { + if !cli.preserve_encodings { + // we don't evaluate to any values here, just verify + // before/after are ignored and we need to handle fixed value deserialization in a different way + // than normal ones. 
+ assert_eq!(before_after.after, ""); + assert_eq!(before_after.before, ""); } - FixedValue::Uint(x) => { - if cli.preserve_encodings { - deser_code.content.line(&format!( - "let ({}_value, {}_encoding) = {}.unsigned_integer_sz()?;", - config.var_name, config.var_name, deserializer_name - )); - } else { - deser_code.content.line(&format!( - "let {}_value = {}.unsigned_integer()?;", - config.var_name, deserializer_name - )); - } - let mut compare_block = - Block::new(format!("if {}_value != {}", config.var_name, x)); - compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Uint({}_value), expected: Key::Uint({}) }}.into());", config.var_name, x)); - deser_code.content.push_block(compare_block); - if cli.preserve_encodings { - config - .final_exprs - .push(format!("Some({}_encoding)", config.var_name)); - deser_code.content.line(&format!( - "{}{}{}", - before_after.before_str(false), - final_expr(config.final_exprs, None), - before_after.after_str(false) - )); - //body.line(&format!("{}{}{}_encoding{}{}", before, sp, var_name, ep, after)); - } + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.throws = true; + deser_code.read_len_used = true; } - FixedValue::Nint(x) => { - if cli.preserve_encodings { - deser_code.content.line(&format!( - "let ({}_value, {}_encoding) = {}.negative_integer_sz()?;", - config.var_name, config.var_name, deserializer_name - )); - } else { - // we use the _sz variant here too to get around imcomplete nint support in the regular negative_integer() - deser_code.content.line(&format!( - "let ({}_value, _) = {}.negative_integer_sz()?;", - config.var_name, deserializer_name - )); - } - let x_abs = (x + 1).abs(); - let mut compare_block = - Block::new(format!("if {}_value != {}", config.var_name, x)); - compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Uint(({}_value + 1).abs() as u64), expected: Key::Uint({}) }}.into());", 
config.var_name, x_abs)); - deser_code.content.push_block(compare_block); - if cli.preserve_encodings { - config - .final_exprs - .push(format!("Some({}_encoding)", config.var_name)); - deser_code.content.line(&format!( - "{}{}{}", - before_after.before_str(false), - final_expr(config.final_exprs, None), - before_after.after_str(false) + match f { + FixedValue::Null => { + let mut special_block = Block::new(format!( + "if {deserializer_name}.special()? != cbor_event::Special::Null" )); - //body.line(&format!("{}{}{}_encoding{}{}", before, sp, var_name, ep, after)); + special_block + .line("return Err(DeserializeFailure::ExpectedNull.into());"); + deser_code.content.push_block(special_block); + if cli.preserve_encodings { + deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + final_expr(config.final_exprs, None), + before_after.after_str(false) + )); + } } - } - FixedValue::Text(x) => { - if cli.preserve_encodings { - deser_code.content.line(&format!( - "let ({}_value, {}_encoding) = {}.text_sz()?;", - config.var_name, config.var_name, deserializer_name - )); - } else { - deser_code.content.line(&format!( - "let {}_value = {}.text()?;", - config.var_name, deserializer_name - )); + FixedValue::Uint(x) => { + if cli.preserve_encodings { + deser_code.content.line(&format!( + "let ({}_value, {}_encoding) = {}.unsigned_integer_sz()?;", + config.var_name, config.var_name, deserializer_name + )); + } else { + deser_code.content.line(&format!( + "let {}_value = {}.unsigned_integer()?;", + config.var_name, deserializer_name + )); + } + let mut compare_block = + Block::new(format!("if {}_value != {}", config.var_name, x)); + compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Uint({}_value), expected: Key::Uint({}) }}.into());", config.var_name, x)); + deser_code.content.push_block(compare_block); + if cli.preserve_encodings { + config + .final_exprs + .push(format!("Some({}_encoding)", config.var_name)); + 
deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + final_expr(config.final_exprs, None), + before_after.after_str(false) + )); + //body.line(&format!("{}{}{}_encoding{}{}", before, sp, var_name, ep, after)); + } } - let mut compare_block = - Block::new(format!("if {}_value != \"{}\"", config.var_name, x)); - compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Str({}_value), expected: Key::Str(String::from(\"{}\")) }}.into());", config.var_name, x)); - deser_code.content.push_block(compare_block); - if cli.preserve_encodings { - config.final_exprs.push(format!( - "StringEncoding::from({}_encoding)", - config.var_name - )); - deser_code.content.line(&format!( - "{}{}{}", - before_after.before_str(false), - final_expr(config.final_exprs, None), - before_after.after_str(false) - )); + FixedValue::Nint(x) => { + if cli.preserve_encodings { + deser_code.content.line(&format!( + "let ({}_value, {}_encoding) = {}.negative_integer_sz()?;", + config.var_name, config.var_name, deserializer_name + )); + } else { + // we use the _sz variant here too to get around imcomplete nint support in the regular negative_integer() + deser_code.content.line(&format!( + "let ({}_value, _) = {}.negative_integer_sz()?;", + config.var_name, deserializer_name + )); + } + let x_abs = (x + 1).abs(); + let mut compare_block = + Block::new(format!("if {}_value != {}", config.var_name, x)); + compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Uint(({}_value + 1).abs() as u64), expected: Key::Uint({}) }}.into());", config.var_name, x_abs)); + deser_code.content.push_block(compare_block); + if cli.preserve_encodings { + config + .final_exprs + .push(format!("Some({}_encoding)", config.var_name)); + deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + final_expr(config.final_exprs, None), + before_after.after_str(false) + )); + 
//body.line(&format!("{}{}{}_encoding{}{}", before, sp, var_name, ep, after)); + } } - } - FixedValue::Float(x) => { - deser_code.content.line(&format!( - "let {}_value = {}.float()?;", - config.var_name, deserializer_name - )); - let mut compare_block = - Block::new(format!("if {}_value != {}", config.var_name, x)); - compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Float({}_value), expected: Key::Float({}) }}.into());", config.var_name, x)); - deser_code.content.push_block(compare_block); - if cli.preserve_encodings { - unimplemented!("preserve_encodings is not implemented for float") + FixedValue::Text(x) => { + if cli.preserve_encodings { + deser_code.content.line(&format!( + "let ({}_value, {}_encoding) = {}.text_sz()?;", + config.var_name, config.var_name, deserializer_name + )); + } else { + deser_code.content.line(&format!( + "let {}_value = {}.text()?;", + config.var_name, deserializer_name + )); + } + let mut compare_block = + Block::new(format!("if {}_value != \"{}\"", config.var_name, x)); + compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Str({}_value), expected: Key::Str(String::from(\"{}\")) }}.into());", config.var_name, x)); + deser_code.content.push_block(compare_block); + if cli.preserve_encodings { + config.final_exprs.push(format!( + "StringEncoding::from({}_encoding)", + config.var_name + )); + deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + final_expr(config.final_exprs, None), + before_after.after_str(false) + )); + } } - } - _ => unimplemented!(), - }; - deser_code.throws = true; - // this block needs to evaluate to a Result even though it has no value - if !cli.preserve_encodings && before_after.expects_result { - deser_code.content.line("Ok(())"); - } - } - SerializingRustType::Root(ConceptualRustType::Primitive(p), type_cfg) => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - 
deser_code.read_len_used = true; - deser_code.throws = true; - } - let error_convert = if before_after.expects_result { - convert_err_to_ours - } else { - "" - }; - let non_preserve_bounds_fn = - |x: &str, bounds: &Option<(Option, Option)>| match bounds { - // always convert error to have consistent E for the and_then - Some(bounds) => Cow::Owned(format!( - "{}.and_then(|{}| {} else {{ Ok({}) }})", - convert_err_to_ours, - x, - bounds_check_if_block(bounds, &bounds_check_expr(*p, x), false), - x, - )), - None => Cow::Borrowed(""), - }; - let mut deser_primitive = - |mut final_exprs: Vec, func: &str, x: &str, x_expr: &str| { - if cli.preserve_encodings { - let enc_expr = match func { - "text" | "bytes" => "StringEncoding::from(enc)", - _ => "Some(enc)", - }; - final_exprs.push(enc_expr.to_owned()); - let enc_map_fn = match &type_cfg.bounds { - // always convert error to have consistent E for the and_then - Some(bounds) => format!( - "{}.and_then(|({}, enc)| {} else {{ Ok({}) }})", - convert_err_to_ours, - x, - bounds_check_if_block(bounds, &bounds_check_expr(*p, x), false), - final_expr(final_exprs, Some(x_expr.to_owned())), - ), - None => format!( - ".map(|({}, enc)| {})", - x, - final_expr(final_exprs, Some(x_expr.to_owned())) - ), - }; + FixedValue::Float(x) => { deser_code.content.line(&format!( - "{}{}.{}_sz(){}{}{}", - before_after.before_str(true), - deserializer_name, - func, - error_convert, - enc_map_fn, - before_after.after_str(true) - )); - } else { - deser_code.content.line(&format!( - "{}{}.{}(){}? 
as {}{}", - before_after.before_str(false), - deserializer_name, - func, - non_preserve_bounds_fn(x, &type_cfg.bounds), - p, - before_after.after_str(false) + "let {}_value = {}.float()?;", + config.var_name, deserializer_name )); - deser_code.throws = true; + let mut compare_block = + Block::new(format!("if {}_value != {}", config.var_name, x)); + compare_block.line(format!("return Err(DeserializeFailure::FixedValueMismatch{{ found: Key::Float({}_value), expected: Key::Float({}) }}.into());", config.var_name, x)); + deser_code.content.push_block(compare_block); + if cli.preserve_encodings { + unimplemented!("preserve_encodings is not implemented for float") + } } + _ => unimplemented!(), }; - match p { - Primitive::Bytes => { - deser_primitive(config.final_exprs, "bytes", "bytes", "bytes") + deser_code.throws = true; + // this block needs to evaluate to a Result even though it has no value + if !cli.preserve_encodings && before_after.expects_result { + deser_code.content.line("Ok(())"); } - Primitive::U8 | Primitive::U16 | Primitive::U32 => deser_primitive( - config.final_exprs, - "unsigned_integer", - "x", - &format!("x as {}", p), - ), - Primitive::U64 => { - deser_primitive(config.final_exprs, "unsigned_integer", "x", "x") + } + SerializingRustType::Root(ConceptualRustType::Primitive(p), type_cfg) => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; + deser_code.throws = true; } - Primitive::I8 | Primitive::I16 | Primitive::I32 | Primitive::I64 => { - // we need to only look at poisitve or negative bounds to avoid comparing e.g. 
a u64 with a negative - let positive_bounds = type_cfg.bounds.map(|(lower, upper)| { - (lower.filter(|l| *l > 0), upper.filter(|u| *u > 0)) - }); - let negative_bounds = type_cfg.bounds.map(|(lower, upper)| { - (lower.filter(|l| *l < 0), upper.filter(|u| *u < 0)) - }); - let mut type_check = Block::new(format!( - "{}match {}.cbor_type()?", - before_after.before_str(false), - deserializer_name - )); - if cli.preserve_encodings { - let bounds_fn = - |bounds: &Option<(Option, Option)>| match bounds { - // always convert error to have consistent E for the and_then - Some(bounds) => Cow::Owned(format!( - "{}.and_then(|(x, enc)| {} else {{ Ok((x, enc)) }})", - convert_err_to_ours, + let error_convert = if before_after.expects_result { + convert_err_to_ours + } else { + "" + }; + let non_preserve_bounds_fn = + |x: &str, bounds: &Option<(Option, Option)>| match bounds { + // always convert error to have consistent E for the and_then + Some(bounds) => Cow::Owned(format!( + "{}.and_then(|{}| {} else {{ Ok({}) }})", + convert_err_to_ours, + x, + bounds_check_if_block(bounds, &bounds_check_expr(*p, x), false), + x, + )), + None => Cow::Borrowed(""), + }; + let mut deser_primitive = + |mut final_exprs: Vec, func: &str, x: &str, x_expr: &str| { + if cli.preserve_encodings { + let enc_expr = match func { + "text" | "bytes" => "StringEncoding::from(enc)", + _ => "Some(enc)", + }; + final_exprs.push(enc_expr.to_owned()); + let enc_map_fn = match &type_cfg.bounds { + // always convert error to have consistent E for the and_then + Some(bounds) => format!( + "{}.and_then(|({}, enc)| {} else {{ Ok({}) }})", + convert_err_to_ours, + x, bounds_check_if_block( bounds, - &bounds_check_expr(*p, "x"), + &bounds_check_expr(*p, x), false ), - )), - None => Cow::Borrowed(""), + final_expr(final_exprs, Some(x_expr.to_owned())), + ), + None => format!( + ".map(|({}, enc)| {})", + x, + final_expr(final_exprs, Some(x_expr.to_owned())) + ), }; - let mut pos = 
Block::new("cbor_event::Type::UnsignedInteger =>"); - pos.line(&format!( - "let (x, enc) = {}.unsigned_integer_sz(){}?;", - deserializer_name, - bounds_fn(&positive_bounds) - )) - .line(format!("(x as {}, Some(enc))", p)) - .after(","); - type_check.push_block(pos); - // let this cover both the negative int case + error case - let mut neg = Block::new("_ =>"); - neg.line(&format!( - "let (x, enc) = {}.negative_integer_sz(){}?;", - deserializer_name, - bounds_fn(&negative_bounds) - )) - .line(format!("(x as {}, Some(enc))", p)) - .after(","); - type_check.push_block(neg); - } else { - type_check + deser_code.content.line(&format!( + "{}{}.{}_sz(){}{}{}", + before_after.before_str(true), + deserializer_name, + func, + error_convert, + enc_map_fn, + before_after.after_str(true) + )); + } else { + deser_code.content.line(&format!( + "{}{}.{}(){}? as {}{}", + before_after.before_str(false), + deserializer_name, + func, + non_preserve_bounds_fn(x, &type_cfg.bounds), + p, + before_after.after_str(false) + )); + deser_code.throws = true; + } + }; + match p { + Primitive::Bytes => { + deser_primitive(config.final_exprs, "bytes", "bytes", "bytes") + } + Primitive::U8 | Primitive::U16 | Primitive::U32 => deser_primitive( + config.final_exprs, + "unsigned_integer", + "x", + &format!("x as {}", p), + ), + Primitive::U64 => { + deser_primitive(config.final_exprs, "unsigned_integer", "x", "x") + } + Primitive::I8 | Primitive::I16 | Primitive::I32 | Primitive::I64 => { + // we need to only look at poisitve or negative bounds to avoid comparing e.g. 
a u64 with a negative + let positive_bounds = type_cfg.bounds.map(|(lower, upper)| { + (lower.filter(|l| *l > 0), upper.filter(|u| *u > 0)) + }); + let negative_bounds = type_cfg.bounds.map(|(lower, upper)| { + (lower.filter(|l| *l < 0), upper.filter(|u| *u < 0)) + }); + let mut type_check = Block::new(format!( + "{}match {}.cbor_type()?", + before_after.before_str(false), + deserializer_name + )); + if cli.preserve_encodings { + let bounds_fn = + |bounds: &Option<(Option, Option)>| match bounds { + // always convert error to have consistent E for the and_then + Some(bounds) => Cow::Owned(format!( + "{}.and_then(|(x, enc)| {} else {{ Ok((x, enc)) }})", + convert_err_to_ours, + bounds_check_if_block( + bounds, + &bounds_check_expr(*p, "x"), + false + ), + )), + None => Cow::Borrowed(""), + }; + let mut pos = Block::new("cbor_event::Type::UnsignedInteger =>"); + pos.line(&format!( + "let (x, enc) = {}.unsigned_integer_sz(){}?;", + deserializer_name, + bounds_fn(&positive_bounds) + )) + .line(format!("(x as {}, Some(enc))", p)) + .after(","); + type_check.push_block(pos); + // let this cover both the negative int case + error case + let mut neg = Block::new("_ =>"); + neg.line(&format!( + "let (x, enc) = {}.negative_integer_sz(){}?;", + deserializer_name, + bounds_fn(&negative_bounds) + )) + .line(format!("(x as {}, Some(enc))", p)) + .after(","); + type_check.push_block(neg); + } else { + type_check .line(format!( "cbor_event::Type::UnsignedInteger => {}.unsigned_integer(){}? 
as {},", deserializer_name, non_preserve_bounds_fn("x", &positive_bounds), p)); - // https://github.com/primetype/cbor_event/issues/9 - // cbor_event's negative_integer() doesn't support i64::MIN so we use the _sz function here instead as that one supports all nints - if *p == Primitive::I64 { - let bounds_fn = match &type_cfg.bounds { - Some(bounds) => Cow::Owned(format!( - "{}.and_then(|(x, _enc)| {} else {{ Ok((x, _enc)) }})", - convert_err_to_ours, - bounds_check_if_block( - bounds, - &bounds_check_expr(*p, "x"), - false - ), - )), - None => Cow::Borrowed(""), - }; - type_check.line(format!( + // https://github.com/primetype/cbor_event/issues/9 + // cbor_event's negative_integer() doesn't support i64::MIN so we use the _sz function here instead as that one supports all nints + if *p == Primitive::I64 { + let bounds_fn = match &type_cfg.bounds { + Some(bounds) => Cow::Owned(format!( + "{}.and_then(|(x, _enc)| {} else {{ Ok((x, _enc)) }})", + convert_err_to_ours, + bounds_check_if_block( + bounds, + &bounds_check_expr(*p, "x"), + false + ), + )), + None => Cow::Borrowed(""), + }; + type_check.line(format!( "_ => {}.negative_integer_sz(){}.map(|(x, _enc)| x)? as {},", deserializer_name, bounds_fn, p )); - } else { - type_check.line(format!( - "_ => {}.negative_integer(){}? as {},", - deserializer_name, - non_preserve_bounds_fn("x", &negative_bounds), - p - )); + } else { + type_check.line(format!( + "_ => {}.negative_integer(){}? 
as {},", + deserializer_name, + non_preserve_bounds_fn("x", &negative_bounds), + p + )); + } } + type_check.after(&before_after.after_str(false)); + deser_code.content.push_block(type_check); + deser_code.throws = true; } - type_check.after(&before_after.after_str(false)); - deser_code.content.push_block(type_check); - deser_code.throws = true; - } - Primitive::N64 => { - if cli.preserve_encodings { - deser_primitive( - config.final_exprs, - "negative_integer", - "x", - "(x + 1).abs() as u64", - ) - } else { - // https://github.com/primetype/cbor_event/issues/9 - // cbor_event's negative_integer() doesn't support full nint range so we use the _sz function here instead as that one supports all nints - let bounds_fn = match &type_cfg.bounds { - Some(bounds) => Cow::Owned(format!( + Primitive::N64 => { + if cli.preserve_encodings { + deser_primitive( + config.final_exprs, + "negative_integer", + "x", + "(x + 1).abs() as u64", + ) + } else { + // https://github.com/primetype/cbor_event/issues/9 + // cbor_event's negative_integer() doesn't support full nint range so we use the _sz function here instead as that one supports all nints + let bounds_fn = match &type_cfg.bounds { + Some(bounds) => Cow::Owned(format!( ".and_then(|(x, _enc)| {} else {{ Ok((x + 1).abs() as u64) }})", bounds_check_if_block( bounds, @@ -2330,628 +2466,445 @@ impl GenerationScope { false ), )), - None => Cow::Borrowed(".map(|(x, _enc)| (x + 1).abs() as u64)"), - }; - deser_code.content.line(&format!( - "{}{}.negative_integer_sz(){}{}{}", - before_after.before_str(true), - deserializer_name, - error_convert, - bounds_fn, - before_after.after_str(true) + None => Cow::Borrowed(".map(|(x, _enc)| (x + 1).abs() as u64)"), + }; + deser_code.content.line(&format!( + "{}{}.negative_integer_sz(){}{}{}", + before_after.before_str(true), + deserializer_name, + error_convert, + bounds_fn, + before_after.after_str(true) + )); + } + } + Primitive::Str => deser_primitive(config.final_exprs, "text", "s", "s"), + 
Primitive::Bool => { + // no encoding differences for bool + deser_code.content.line(&final_result_expr_complete( + &mut deser_code.throws, + config.final_exprs, + "raw.bool().map_err(Into::into)", )); } - } - Primitive::Str => deser_primitive(config.final_exprs, "text", "s", "s"), - Primitive::Bool => { - // no encoding differences for bool - deser_code.content.line(&final_result_expr_complete( - &mut deser_code.throws, - config.final_exprs, - "raw.bool().map_err(Into::into)", - )); - } - Primitive::F32 => { - deser_code.content.line(&final_result_expr_complete( - &mut deser_code.throws, - config.final_exprs, - "f32::deserialize(raw)", - )); - if cli.preserve_encodings { - unimplemented!("preserve_encodings is not implemented for float") + Primitive::F32 => { + deser_code.content.line(&final_result_expr_complete( + &mut deser_code.throws, + config.final_exprs, + "f32::deserialize(raw)", + )); + if cli.preserve_encodings { + unimplemented!("preserve_encodings is not implemented for float") + } + if type_cfg.bounds.is_some() { + unimplemented!("bounds not supported for floats") + } } - if type_cfg.bounds.is_some() { - unimplemented!("bounds not supported for floats") + Primitive::F64 => { + deser_code.content.line(&final_result_expr_complete( + &mut deser_code.throws, + config.final_exprs, + "f64::deserialize(raw)", + )); + if cli.preserve_encodings { + unimplemented!("preserve_encodings is not implemented for float") + } + if type_cfg.bounds.is_some() { + unimplemented!("bounds not supported for floats") + } } - } - Primitive::F64 => { + }; + } + SerializingRustType::Root(ConceptualRustType::Rust(ident), type_cfg) => { + // check for type-level @custom_deserialize + if let Some(custom_deserialize) = &types + .rust_struct(ident) + .unwrap() + .config() + .custom_deserialize + { + // because this is type-level we must handle final_exprs as it could be wrapped in a tag, etc deser_code.content.line(&final_result_expr_complete( &mut deser_code.throws, 
config.final_exprs, - "f64::deserialize(raw)", + &format!("{}({})", custom_deserialize, deserializer_name), )); - if cli.preserve_encodings { - unimplemented!("preserve_encodings is not implemented for float") - } - if type_cfg.bounds.is_some() { - unimplemented!("bounds not supported for floats") - } - } - }; - } - SerializingRustType::Root(ConceptualRustType::Rust(ident), type_cfg) => { - match &types.rust_struct(ident).unwrap().variant() { - RustStructType::CStyleEnum { variants } => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.throws = true; - deser_code.read_len_used = true; - } - // iflet Some(common) = enum_variants_common_constant_type(variants) { - // // TODO: potentially simplified deserialization some day - // // issue: https://github.com/dcSpark/cddl-codegen/issues/145 - // } else { - deser_code.content.line( - "let initial_position = raw.as_mut_ref().stream_position().unwrap();", - ); - let mut variant_final_exprs = config.final_exprs.clone(); - if cli.preserve_encodings { - for enc_var in encoding_fields( - types, - config.var_name, - variants[0].rust_type(), - false, - cli, - ) { - variant_final_exprs.push(enc_var.field_name); - } - } - for variant in variants { - let mut return_if_deserialized = - make_enum_variant_return_if_deserialized( - self, - types, - variant, - variant_final_exprs.is_empty(), - None, - &mut deser_code.content, - cli, + } else { + match &types.rust_struct(ident).unwrap().variant() { + RustStructType::CStyleEnum { variants } => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.throws = true; + deser_code.read_len_used = true; + } + // iflet Some(common) = enum_variants_common_constant_type(variants) { + // // TODO: potentially simplified deserialization some day + // // issue: https://github.com/dcSpark/cddl-codegen/issues/145 + // } else { + deser_code.content.line( + "let initial_position = 
raw.as_mut_ref().stream_position().unwrap();", ); - return_if_deserialized + let mut variant_final_exprs = config.final_exprs.clone(); + if cli.preserve_encodings { + for enc_var in encoding_fields( + types, + config.var_name, + variants[0].rust_type(), + false, + cli, + ) { + variant_final_exprs.push(enc_var.field_name); + } + } + for variant in variants { + let mut return_if_deserialized = + make_enum_variant_return_if_deserialized( + self, + types, + variant, + variant_final_exprs.is_empty(), + None, + &mut deser_code.content, + cli, + ); + return_if_deserialized .line(format!("Ok(({})) => return Ok({}),", variant_final_exprs.join(", "), final_expr(variant_final_exprs.clone(), Some(format!("{}::{}", ident, variant.name))))) .line("Err(_) => raw.as_mut_ref().seek(SeekFrom::Start(initial_position)).unwrap(),") .after(";"); - deser_code.content.push_block(return_if_deserialized); - } - deser_code.content.line(&format!( + deser_code.content.push_block(return_if_deserialized); + } + deser_code.content.line(&format!( "Err(DeserializeError::new(\"{ident}\", DeserializeFailure::NoVariantMatched))" )); - } - RustStructType::RawBytesType => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.throws = true; - deser_code.read_len_used = true; - } - if cli.preserve_encodings { - config - .final_exprs - .push("StringEncoding::from(enc)".to_owned()); - let from_raw_bytes_with_conversions = format!( + } + RustStructType::RawBytesType => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.throws = true; + deser_code.read_len_used = true; + } + if cli.preserve_encodings { + config + .final_exprs + .push("StringEncoding::from(enc)".to_owned()); + let from_raw_bytes_with_conversions = format!( "{}::from_raw_bytes(&bytes).map(|bytes| {}).map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into())", ident, final_expr(config.final_exprs, Some("bytes".to_owned())) ); - 
deser_code.content.line(&format!( - "{}{}.bytes_sz(){}.and_then(|(bytes, enc)| {}){}", - before_after.before_str(true), - deserializer_name, - convert_err_to_ours, - from_raw_bytes_with_conversions, - before_after.after_str(true) - )); - } else { - let from_raw_bytes_with_conversions = format!( + deser_code.content.line(&format!( + "{}{}.bytes_sz(){}.and_then(|(bytes, enc)| {}){}", + before_after.before_str(true), + deserializer_name, + convert_err_to_ours, + from_raw_bytes_with_conversions, + before_after.after_str(true) + )); + } else { + let from_raw_bytes_with_conversions = format!( "{ident}::from_raw_bytes(&bytes).map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into())" ); - deser_code.content.line(&format!( - "{}{}.bytes(){}.and_then(|bytes| {}){}", - before_after.before_str(true), - deserializer_name, - convert_err_to_ours, - from_raw_bytes_with_conversions, - before_after.after_str(true) - )); - } - } - _ => { - if types.is_plain_group(ident) && !type_cfg.basic_override { - // This would mess up with length checks otherwise and is probably not a likely situation if this is even valid in CDDL. - // To have this work (if it's valid) you'd either need to generate 2 embedded deserialize methods or pass - // a parameter whether it was an optional field, and if so, read_len.read_elems(embedded mandatory fields)?; - // since otherwise it'd only length check the optional fields within the type. 
- assert!(!config.optional_field); - deser_code.read_len_used = true; - let final_expr_value = format!( - "{}::deserialize_as_embedded_group({}, {}, len)", - ident, - deserializer_name, - config.pass_read_len() - ); + deser_code.content.line(&format!( + "{}{}.bytes(){}.and_then(|bytes| {}){}", + before_after.before_str(true), + deserializer_name, + convert_err_to_ours, + from_raw_bytes_with_conversions, + before_after.after_str(true) + )); + } + } + _ => { + if types.is_plain_group(ident) && !type_cfg.basic_override { + // This would mess up with length checks otherwise and is probably not a likely situation if this is even valid in CDDL. + // To have this work (if it's valid) you'd either need to generate 2 embedded deserialize methods or pass + // a parameter whether it was an optional field, and if so, read_len.read_elems(embedded mandatory fields)?; + // since otherwise it'd only length check the optional fields within the type. + assert!(!config.optional_field); + deser_code.read_len_used = true; + let final_expr_value = format!( + "{}::deserialize_as_embedded_group({}, {}, len)", + ident, + deserializer_name, + config.pass_read_len() + ); - deser_code.content.line(&final_result_expr_complete( - &mut deser_code.throws, - config.final_exprs, - &final_expr_value, - )); - } else { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.read_len_used = true; - deser_code.throws = true; + deser_code.content.line(&final_result_expr_complete( + &mut deser_code.throws, + config.final_exprs, + &final_expr_value, + )); + } else { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; + deser_code.throws = true; + } + let final_expr_value = + format!("{ident}::deserialize({deserializer_name})"); + deser_code.content.line(&final_result_expr_complete( + &mut deser_code.throws, + config.final_exprs, + &final_expr_value, + )); + } } - let final_expr_value = - 
format!("{ident}::deserialize({deserializer_name})"); - deser_code.content.line(&final_result_expr_complete( - &mut deser_code.throws, - config.final_exprs, - &final_expr_value, - )); } } } - } - SerializingRustType::Root(ConceptualRustType::Optional(ty), _cfg) => { - let read_len_check = - config.optional_field || (ty.expanded_field_count(types) != Some(1)); - // codegen crate doesn't support if/else or appending a block after a block, only strings - // so we need to create a local bool var and use a match instead - let if_label = if ty.cbor_types(types).contains(&cbor_event::Type::Special) { - let is_some_check_var = format!("{}_is_some", config.var_name); - let mut is_some_check = - Block::new(format!("let {is_some_check_var} = match cbor_type()?")); - let mut special_block = Block::new("cbor_event::Type::Special =>"); - special_block.line(&format!("let special = {deserializer_name}.special()?;")); - special_block.line(&format!( + SerializingRustType::Root(ConceptualRustType::Optional(ty), _cfg) => { + let read_len_check = + config.optional_field || (ty.expanded_field_count(types) != Some(1)); + // codegen crate doesn't support if/else or appending a block after a block, only strings + // so we need to create a local bool var and use a match instead + let if_label = if ty.cbor_types(types).contains(&cbor_event::Type::Special) { + let is_some_check_var = format!("{}_is_some", config.var_name); + let mut is_some_check = + Block::new(format!("let {is_some_check_var} = match cbor_type()?")); + let mut special_block = Block::new("cbor_event::Type::Special =>"); + special_block + .line(&format!("let special = {deserializer_name}.special()?;")); + special_block.line(&format!( "{deserializer_name}.as_mut_ref().seek(SeekFrom::Current(-1)).unwrap();" )); - let mut special_match = Block::new("match special"); - // TODO: we need to check that we don't have null / null somewhere - special_match.line("cbor_event::Special::Null => false,"); - // no need to error check - would 
happen in generated deserialize code - special_match.line("_ => true,"); - special_block.push_block(special_match); - special_block.after(","); - is_some_check.push_block(special_block); - // it's possible the Some case only has Special as its starting tag(s), - // but we don't care since it'll fail in either either case anyway, - // and would give a good enough error (ie expected Special::X but found non-Special) - is_some_check.line("_ => true,"); - is_some_check.after(";"); - deser_code.content.push_block(is_some_check); - is_some_check_var - } else { - String::from(&format!( - "{deserializer_name}.cbor_type()? != cbor_event::Type::Special" - )) - }; - let mut deser_block = Block::new(format!( - "{}match {}", - before_after.before_str(false), - if_label - )); - let mut some_block = Block::new("true =>"); - if read_len_check { - let mandatory_fields = ty.expanded_mandatory_field_count(types); - if mandatory_fields != 0 { - some_block.line(format!("read_len.read_elems({mandatory_fields})?;")); - deser_code.read_len_used = true; - } - } - let ty_enc_fields = if cli.preserve_encodings { - encoding_fields( - types, - config.var_name, - &ty.clone().resolve_aliases(), - false, - cli, - ) - } else { - vec![] - }; - if ty_enc_fields.is_empty() { - self.generate_deserialize( - types, - (&**ty).into(), - DeserializeBeforeAfter::new("Some(", ")", false), - config.optional_field(false), - cli, - ) - .add_to(&mut some_block); - } else { - let (map_some_before, map_some_after) = if ty.is_fixed_value() { - // case 1: no actual return, only encoding values for tags/fixed values, no need to wrap in Some() - ("", "".to_owned()) - } else { - // case 2: need to map FIRST element in Some(x) - let enc_vars_str = ty_enc_fields - .iter() - .map(|enc_field| enc_field.field_name.clone()) - .collect::>() - .join(", "); - // we need to annotate the Ok's error type since the compiler gets confused otherwise - ( - "Result::<_, DeserializeError>::Ok(", - format!(").map(|(x, {enc_vars_str})| 
(Some(x), {enc_vars_str}))?"), - ) - }; - self.generate_deserialize( - types, - (&**ty).into(), - DeserializeBeforeAfter::new(map_some_before, &map_some_after, false), - config.optional_field(false), - cli, - ) - .add_to(&mut some_block); - } - some_block.after(","); - deser_block.push_block(some_block); - let mut none_block = Block::new("false =>"); - if read_len_check { - none_block.line("read_len.read_elems(1)?;"); - deser_code.read_len_used = true; - } - // we don't use this to avoid the new (true) if cli.preserve_encodings is set - //self.generate_deserialize(types, &ConceptualRustType::Fixed(FixedValue::Null), var_name, "", "", in_embedded, false, add_parens, &mut none_block); - let mut check_null = Block::new(format!( - "if {deserializer_name}.special()? != cbor_event::Special::Null" - )); - check_null.line("return Err(DeserializeFailure::ExpectedNull.into());"); - none_block.push_block(check_null); - if cli.preserve_encodings { - let mut none_elems = if ty.is_fixed_value() { - vec![] + let mut special_match = Block::new("match special"); + // TODO: we need to check that we don't have null / null somewhere + special_match.line("cbor_event::Special::Null => false,"); + // no need to error check - would happen in generated deserialize code + special_match.line("_ => true,"); + special_block.push_block(special_match); + special_block.after(","); + is_some_check.push_block(special_block); + // it's possible the Some case only has Special as its starting tag(s), + // but we don't care since it'll fail in either either case anyway, + // and would give a good enough error (ie expected Special::X but found non-Special) + is_some_check.line("_ => true,"); + is_some_check.after(";"); + deser_code.content.push_block(is_some_check); + is_some_check_var } else { - vec!["None".to_owned()] - }; - none_elems.extend( - ty_enc_fields - .iter() - .map(|enc_field| enc_field.default_expr.to_owned()), - ); - match none_elems.len() { - // this probably isn't properly supported by 
other parts of code and is so unlikely to be encountered - // that we really don't care right now. if you run into this open an issue and it can be investigated - 0 => unimplemented!("please open a github issue"), - 1 => none_block.line(none_elems.first().unwrap()), - _ => none_block.line(format!("({})", none_elems.join(", "))), + String::from(&format!( + "{deserializer_name}.cbor_type()? != cbor_event::Type::Special" + )) }; - } else { - none_block.line("None"); - } - deser_block.after(&before_after.after_str(false)); - deser_block.push_block(none_block); - deser_code.content.push_block(deser_block); - deser_code.throws = true; - } - SerializingRustType::Root(ConceptualRustType::Array(ty), type_cfg) => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.read_len_used = true; - } - let arr_var_name = format!("{}_arr", config.var_name); - deser_code - .content - .line(&format!("let mut {arr_var_name} = Vec::new();")); - let elem_var_name = format!("{}_elem", config.var_name); - let elem_encs = if cli.preserve_encodings { - encoding_fields( - types, - &elem_var_name, - &ty.clone().resolve_aliases(), - false, - cli, - ) - } else { - vec![] - }; - if cli.preserve_encodings { - deser_code - .content - .line(&format!("let len = {deserializer_name}.array_sz()?;")) - .line(&format!("let {}_encoding = len.into();", config.var_name)); - if !elem_encs.is_empty() { - deser_code.content.line(&format!( - "let mut {}_elem_encodings = Vec::new();", - config.var_name - )); - } - } else { - deser_code - .content - .line(&format!("let len = {deserializer_name}.array()?;")); - } - let mut elem_config = DeserializeConfig::new(&elem_var_name); - let (mut deser_loop, plain_len_check) = match &ty.conceptual_type { - ConceptualRustType::Rust(ty_ident) if types.is_plain_group(ty_ident) => { - // two things that must be done differently for embedded plain groups: - // 1) We can't directly read the CBOR len's number of items since it could be >1 
- // 2) We need a different cbor read len var to pass into embedded deserialize - let read_len_overload = format!("{}_read_len", config.var_name); - deser_code.content.line(&format!( - "let mut {read_len_overload} = CBORReadLen::new(len);" - )); - // inside of deserialize_as_embedded_group we only modify read_len for things we couldn't - // statically know beforehand. This was done for other areas that use plain groups in order - // to be able to do static length checks for statically sized groups that contain plain groups - // at the start of deserialization instead of many checks for every single field. - let plain_len_check = match ty.expanded_mandatory_field_count(types) { - 0 => None, - n => Some(format!("{read_len_overload}.read_elems({n})?;")), - }; - elem_config = elem_config.overload_read_len(read_len_overload); - let deser_loop = make_deser_loop( - "len", - &format!("{}_read_len.read()", config.var_name), - cli, - ); - (deser_loop, plain_len_check) - } - _ => ( - make_deser_loop("len", &format!("({arr_var_name}.len() as u64)"), cli), - None, - ), - }; - deser_loop.push_block(make_deser_loop_break_check()); - if let Some(plain_len_check) = plain_len_check { - deser_loop.line(plain_len_check); - } - elem_config.deserializer_name_overload = config.deserializer_name_overload; - if !elem_encs.is_empty() { - let elem_var_names_str = encoding_var_names_str(types, &elem_var_name, ty, cli); - self.generate_deserialize( - types, - (&**ty).into(), - DeserializeBeforeAfter::new( - &format!("let {elem_var_names_str} = "), - ";", - false, - ), - elem_config, - cli, - ) - .add_to(&mut deser_loop); - deser_loop - .line(format!("{arr_var_name}.push({elem_var_name});")) - .line(format!( - "{}_elem_encodings.push({});", - config.var_name, - tuple_str(elem_encs.iter().map(|enc| enc.field_name.clone()).collect()) - )); - } else { - self.generate_deserialize( - types, - (&**ty).into(), - DeserializeBeforeAfter::new(&format!("{arr_var_name}.push("), ");", false), - elem_config, 
- cli, - ) - .add_to(&mut deser_loop); - } - deser_code.content.push_block(deser_loop); - if let Some(bounds) = &type_cfg.bounds { - // we use cargo fmt after so it's okay if we just use .line() here - deser_code.content.line(&bounds_check_if_block( - bounds, - &format!("{arr_var_name}.len()"), - true, - )); - } - if cli.preserve_encodings { - config - .final_exprs - .push(format!("{}_encoding", config.var_name)); - if !elem_encs.is_empty() { - config - .final_exprs - .push(format!("{}_elem_encodings", config.var_name)); - } - deser_code.content.line(&format!( - "{}{}{}", - before_after.before_str(false), - final_expr(config.final_exprs, Some(arr_var_name)), - before_after.after_str(false) - )); - } else { - deser_code.content.line(&format!( - "{}{}{}", + let mut deser_block = Block::new(format!( + "{}match {}", before_after.before_str(false), - arr_var_name, - before_after.after_str(false) + if_label )); - } - deser_code.throws = true; - } - SerializingRustType::Root(ConceptualRustType::Map(key_type, value_type), type_cfg) => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.read_len_used = true; - } - if !self.deserialize_generated_for_type(types, &key_type.conceptual_type) { - todo!(); - // TODO: where is the best place to check for this? should we pass in a RustIdent to say where we're generating?! 
- //self.dont_generate_deserialize(name, format!("key type {} doesn't support deserialize", key_type.for_rust_member())); - } else if !self.deserialize_generated_for_type(types, &value_type.conceptual_type) { - todo!(); - //self.dont_generate_deserialize(name, format!("value type {} doesn't support deserialize", value_type.for_rust_member())); - } else { - let table_var = format!("{}_table", config.var_name); - deser_code.content.line(&format!( - "let mut {} = {}::new();", - table_var, - table_type(cli) - )); - let key_var_name = format!("{}_key", config.var_name); - let value_var_name = format!("{}_value", config.var_name); - let key_encs = if cli.preserve_encodings { + let mut some_block = Block::new("true =>"); + if read_len_check { + let mandatory_fields = ty.expanded_mandatory_field_count(types); + if mandatory_fields != 0 { + some_block.line(format!("read_len.read_elems({mandatory_fields})?;")); + deser_code.read_len_used = true; + } + } + let ty_enc_fields = if cli.preserve_encodings { encoding_fields( types, - &key_var_name, - &key_type.clone().resolve_aliases(), + config.var_name, + &ty.clone().resolve_aliases(), false, cli, ) } else { vec![] }; - let value_encs = if cli.preserve_encodings { + if ty_enc_fields.is_empty() { + self.generate_deserialize( + types, + (&**ty).into(), + DeserializeBeforeAfter::new("Some(", ")", false), + config.optional_field(false), + cli, + ) + .add_to(&mut some_block); + } else { + let (map_some_before, map_some_after) = if ty.is_fixed_value() { + // case 1: no actual return, only encoding values for tags/fixed values, no need to wrap in Some() + ("", "".to_owned()) + } else { + // case 2: need to map FIRST element in Some(x) + let enc_vars_str = ty_enc_fields + .iter() + .map(|enc_field| enc_field.field_name.clone()) + .collect::>() + .join(", "); + // we need to annotate the Ok's error type since the compiler gets confused otherwise + ( + "Result::<_, DeserializeError>::Ok(", + format!(").map(|(x, {enc_vars_str})| (Some(x), 
{enc_vars_str}))?"), + ) + }; + self.generate_deserialize( + types, + (&**ty).into(), + DeserializeBeforeAfter::new(map_some_before, &map_some_after, false), + config.optional_field(false), + cli, + ) + .add_to(&mut some_block); + } + some_block.after(","); + deser_block.push_block(some_block); + let mut none_block = Block::new("false =>"); + if read_len_check { + none_block.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; + } + // we don't use this to avoid the new (true) if cli.preserve_encodings is set + //self.generate_deserialize(types, &ConceptualRustType::Fixed(FixedValue::Null), var_name, "", "", in_embedded, false, add_parens, &mut none_block); + let mut check_null = Block::new(format!( + "if {deserializer_name}.special()? != cbor_event::Special::Null" + )); + check_null.line("return Err(DeserializeFailure::ExpectedNull.into());"); + none_block.push_block(check_null); + if cli.preserve_encodings { + let mut none_elems = if ty.is_fixed_value() { + vec![] + } else { + vec!["None".to_owned()] + }; + none_elems.extend( + ty_enc_fields + .iter() + .map(|enc_field| enc_field.default_expr.to_owned()), + ); + match none_elems.len() { + // this probably isn't properly supported by other parts of code and is so unlikely to be encountered + // that we really don't care right now. 
if you run into this open an issue and it can be investigated + 0 => unimplemented!("please open a github issue"), + 1 => none_block.line(none_elems.first().unwrap()), + _ => none_block.line(format!("({})", none_elems.join(", "))), + }; + } else { + none_block.line("None"); + } + deser_block.after(&before_after.after_str(false)); + deser_block.push_block(none_block); + deser_code.content.push_block(deser_block); + deser_code.throws = true; + } + SerializingRustType::Root(ConceptualRustType::Array(ty), type_cfg) => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; + } + let arr_var_name = format!("{}_arr", config.var_name); + deser_code + .content + .line(&format!("let mut {arr_var_name} = Vec::new();")); + let elem_var_name = format!("{}_elem", config.var_name); + let elem_encs = if cli.preserve_encodings { encoding_fields( types, - &value_var_name, - &value_type.clone().resolve_aliases(), + &elem_var_name, + &ty.clone().resolve_aliases(), false, cli, ) } else { vec![] }; - let len_var = format!("{}_len", config.var_name); if cli.preserve_encodings { deser_code .content - .line(&format!("let {len_var} = {deserializer_name}.map_sz()?;")) - .line(&format!( - "let {}_encoding = {}.into();", - config.var_name, len_var - )); - if !key_encs.is_empty() { + .line(&format!("let len = {deserializer_name}.array_sz()?;")) + .line(&format!("let {}_encoding = len.into();", config.var_name)); + if !elem_encs.is_empty() { deser_code.content.line(&format!( - "let mut {}_key_encodings = BTreeMap::new();", - config.var_name - )); - } - if !value_encs.is_empty() { - deser_code.content.line(&format!( - "let mut {}_value_encodings = BTreeMap::new();", + "let mut {}_elem_encodings = Vec::new();", config.var_name )); } } else { deser_code .content - .line(&format!("let {len_var} = {deserializer_name}.map()?;")); + .line(&format!("let len = {deserializer_name}.array()?;")); } - let mut deser_loop = - 
make_deser_loop(&len_var, &format!("({table_var}.len() as u64)"), cli); - deser_loop.push_block(make_deser_loop_break_check()); - let mut key_config = DeserializeConfig::new(&key_var_name); - key_config.deserializer_name_overload = config.deserializer_name_overload; - let mut value_config = DeserializeConfig::new(&value_var_name); - value_config.deserializer_name_overload = config.deserializer_name_overload; - let (key_var_names_str, value_var_names_str) = if cli.preserve_encodings { - ( - encoding_var_names_str(types, &key_var_name, key_type, cli), - encoding_var_names_str(types, &value_var_name, value_type, cli), - ) - } else { - (key_var_name.clone(), value_var_name.clone()) - }; - self.generate_deserialize( - types, - (&**key_type).into(), - DeserializeBeforeAfter::new( - &format!("let {key_var_names_str} = "), - ";", - false, - ), - key_config, - cli, - ) - .add_to(&mut deser_loop); - self.generate_deserialize( - types, - (&**value_type).into(), - DeserializeBeforeAfter::new( - &format!("let {value_var_names_str} = "), - ";", - false, - ), - value_config, - cli, - ) - .add_to(&mut deser_loop); - let mut dup_check = Block::new(format!( - "if {}.insert({}{}, {}).is_some()", - table_var, - key_var_name, - if key_type.is_copy(types) { - "" - } else { - ".clone()" - }, - value_var_name - )); - let dup_key_error_key = match &key_type.conceptual_type { - ConceptualRustType::Primitive(Primitive::U8) - | ConceptualRustType::Primitive(Primitive::U16) - | ConceptualRustType::Primitive(Primitive::U32) - | ConceptualRustType::Primitive(Primitive::U64) => { - format!("Key::Uint({key_var_name}.into())") - } - ConceptualRustType::Primitive(Primitive::Str) => { - format!("Key::Str({key_var_name})") - } - // TODO: make a generic one then store serialized CBOR? 
- _ => "Key::Str(String::from(\"some complicated/unsupported type\"))" - .to_owned(), - }; - dup_check.line(format!( - "return Err(DeserializeFailure::DuplicateKey({dup_key_error_key}).into());" - )); - deser_loop.push_block(dup_check); - if cli.preserve_encodings { - if !key_encs.is_empty() { - deser_loop.line(format!( - "{}_key_encodings.insert({}{}, {});", - config.var_name, - key_var_name, - if key_type.encoding_var_is_copy(types) { - "" - } else { - ".clone()" - }, - tuple_str( - key_encs.iter().map(|enc| enc.field_name.clone()).collect() - ) + let mut elem_config = DeserializeConfig::new(&elem_var_name); + let (mut deser_loop, plain_len_check) = match &ty.conceptual_type { + ConceptualRustType::Rust(ty_ident) if types.is_plain_group(ty_ident) => { + // two things that must be done differently for embedded plain groups: + // 1) We can't directly read the CBOR len's number of items since it could be >1 + // 2) We need a different cbor read len var to pass into embedded deserialize + let read_len_overload = format!("{}_read_len", config.var_name); + deser_code.content.line(&format!( + "let mut {read_len_overload} = CBORReadLen::new(len);" )); + // inside of deserialize_as_embedded_group we only modify read_len for things we couldn't + // statically know beforehand. This was done for other areas that use plain groups in order + // to be able to do static length checks for statically sized groups that contain plain groups + // at the start of deserialization instead of many checks for every single field. 
+ let plain_len_check = match ty.expanded_mandatory_field_count(types) { + 0 => None, + n => Some(format!("{read_len_overload}.read_elems({n})?;")), + }; + elem_config = elem_config.overload_read_len(read_len_overload); + let deser_loop = make_deser_loop( + "len", + &format!("{}_read_len.read()", config.var_name), + cli, + ); + (deser_loop, plain_len_check) } - if !value_encs.is_empty() { - deser_loop.line(format!( - "{}_value_encodings.insert({}{}, {});", + _ => ( + make_deser_loop("len", &format!("({arr_var_name}.len() as u64)"), cli), + None, + ), + }; + deser_loop.push_block(make_deser_loop_break_check()); + if let Some(plain_len_check) = plain_len_check { + deser_loop.line(plain_len_check); + } + elem_config.deserializer_name_overload = config.deserializer_name_overload; + if !elem_encs.is_empty() { + let elem_var_names_str = + encoding_var_names_str(types, &elem_var_name, ty, cli); + self.generate_deserialize( + types, + (&**ty).into(), + DeserializeBeforeAfter::new( + &format!("let {elem_var_names_str} = "), + ";", + false, + ), + elem_config, + cli, + ) + .add_to(&mut deser_loop); + deser_loop + .line(format!("{arr_var_name}.push({elem_var_name});")) + .line(format!( + "{}_elem_encodings.push({});", config.var_name, - key_var_name, - if key_type.encoding_var_is_copy(types) { - "" - } else { - ".clone()" - }, tuple_str( - value_encs - .iter() - .map(|enc| enc.field_name.clone()) - .collect() + elem_encs.iter().map(|enc| enc.field_name.clone()).collect() ) )); - } + } else { + self.generate_deserialize( + types, + (&**ty).into(), + DeserializeBeforeAfter::new( + &format!("{arr_var_name}.push("), + ");", + false, + ), + elem_config, + cli, + ) + .add_to(&mut deser_loop); } deser_code.content.push_block(deser_loop); if let Some(bounds) = &type_cfg.bounds { // we use cargo fmt after so it's okay if we just use .line() here deser_code.content.line(&bounds_check_if_block( bounds, - &format!("{table_var}.len()"), + &format!("{arr_var_name}.len()"), true, )); } @@ 
-2959,129 +2912,362 @@ impl GenerationScope { config .final_exprs .push(format!("{}_encoding", config.var_name)); - if !key_encs.is_empty() { + if !elem_encs.is_empty() { config .final_exprs - .push(format!("{}_key_encodings", config.var_name)); - } - if !value_encs.is_empty() { - config - .final_exprs - .push(format!("{}_value_encodings", config.var_name)); + .push(format!("{}_elem_encodings", config.var_name)); } deser_code.content.line(&format!( "{}{}{}", before_after.before_str(false), - final_expr(config.final_exprs, Some(table_var)), + final_expr(config.final_exprs, Some(arr_var_name)), before_after.after_str(false) )); } else { deser_code.content.line(&format!( "{}{}{}", before_after.before_str(false), - table_var, + arr_var_name, before_after.after_str(false) )); } + deser_code.throws = true; } - deser_code.throws = true; - } - SerializingRustType::Root(ConceptualRustType::Alias(_ident, ty), _cfg) => { - self.generate_deserialize(types, (&**ty).into(), before_after, config, cli) + SerializingRustType::Root( + ConceptualRustType::Map(key_type, value_type), + type_cfg, + ) => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; + } + if !self.deserialize_generated_for_type(types, &key_type.conceptual_type) { + todo!(); + // TODO: where is the best place to check for this? should we pass in a RustIdent to say where we're generating?! 
+ //self.dont_generate_deserialize(name, format!("key type {} doesn't support deserialize", key_type.for_rust_member())); + } else if !self + .deserialize_generated_for_type(types, &value_type.conceptual_type) + { + todo!(); + //self.dont_generate_deserialize(name, format!("value type {} doesn't support deserialize", value_type.for_rust_member())); + } else { + let table_var = format!("{}_table", config.var_name); + deser_code.content.line(&format!( + "let mut {} = {}::new();", + table_var, + table_type(cli) + )); + let key_var_name = format!("{}_key", config.var_name); + let value_var_name = format!("{}_value", config.var_name); + let key_encs = if cli.preserve_encodings { + encoding_fields( + types, + &key_var_name, + &key_type.clone().resolve_aliases(), + false, + cli, + ) + } else { + vec![] + }; + let value_encs = if cli.preserve_encodings { + encoding_fields( + types, + &value_var_name, + &value_type.clone().resolve_aliases(), + false, + cli, + ) + } else { + vec![] + }; + let len_var = format!("{}_len", config.var_name); + if cli.preserve_encodings { + deser_code + .content + .line(&format!("let {len_var} = {deserializer_name}.map_sz()?;")) + .line(&format!( + "let {}_encoding = {}.into();", + config.var_name, len_var + )); + if !key_encs.is_empty() { + deser_code.content.line(&format!( + "let mut {}_key_encodings = BTreeMap::new();", + config.var_name + )); + } + if !value_encs.is_empty() { + deser_code.content.line(&format!( + "let mut {}_value_encodings = BTreeMap::new();", + config.var_name + )); + } + } else { + deser_code + .content + .line(&format!("let {len_var} = {deserializer_name}.map()?;")); + } + let mut deser_loop = + make_deser_loop(&len_var, &format!("({table_var}.len() as u64)"), cli); + deser_loop.push_block(make_deser_loop_break_check()); + let mut key_config = DeserializeConfig::new(&key_var_name); + key_config.deserializer_name_overload = config.deserializer_name_overload; + let mut value_config = DeserializeConfig::new(&value_var_name); 
+ value_config.deserializer_name_overload = config.deserializer_name_overload; + let (key_var_names_str, value_var_names_str) = if cli.preserve_encodings { + ( + encoding_var_names_str(types, &key_var_name, key_type, cli), + encoding_var_names_str(types, &value_var_name, value_type, cli), + ) + } else { + (key_var_name.clone(), value_var_name.clone()) + }; + self.generate_deserialize( + types, + (&**key_type).into(), + DeserializeBeforeAfter::new( + &format!("let {key_var_names_str} = "), + ";", + false, + ), + key_config, + cli, + ) + .add_to(&mut deser_loop); + self.generate_deserialize( + types, + (&**value_type).into(), + DeserializeBeforeAfter::new( + &format!("let {value_var_names_str} = "), + ";", + false, + ), + value_config, + cli, + ) + .add_to(&mut deser_loop); + let mut dup_check = Block::new(format!( + "if {}.insert({}{}, {}).is_some()", + table_var, + key_var_name, + if key_type.is_copy(types) { + "" + } else { + ".clone()" + }, + value_var_name + )); + let dup_key_error_key = match &key_type.conceptual_type { + ConceptualRustType::Primitive(Primitive::U8) + | ConceptualRustType::Primitive(Primitive::U16) + | ConceptualRustType::Primitive(Primitive::U32) + | ConceptualRustType::Primitive(Primitive::U64) => { + format!("Key::Uint({key_var_name}.into())") + } + ConceptualRustType::Primitive(Primitive::Str) => { + format!("Key::Str({key_var_name})") + } + // TODO: make a generic one then store serialized CBOR? 
+ _ => "Key::Str(String::from(\"some complicated/unsupported type\"))" + .to_owned(), + }; + dup_check.line(format!( + "return Err(DeserializeFailure::DuplicateKey({dup_key_error_key}).into());" + )); + deser_loop.push_block(dup_check); + if cli.preserve_encodings { + if !key_encs.is_empty() { + deser_loop.line(format!( + "{}_key_encodings.insert({}{}, {});", + config.var_name, + key_var_name, + if key_type.encoding_var_is_copy(types) { + "" + } else { + ".clone()" + }, + tuple_str( + key_encs.iter().map(|enc| enc.field_name.clone()).collect() + ) + )); + } + if !value_encs.is_empty() { + deser_loop.line(format!( + "{}_value_encodings.insert({}{}, {});", + config.var_name, + key_var_name, + if key_type.encoding_var_is_copy(types) { + "" + } else { + ".clone()" + }, + tuple_str( + value_encs + .iter() + .map(|enc| enc.field_name.clone()) + .collect() + ) + )); + } + } + deser_code.content.push_block(deser_loop); + if let Some(bounds) = &type_cfg.bounds { + // we use cargo fmt after so it's okay if we just use .line() here + deser_code.content.line(&bounds_check_if_block( + bounds, + &format!("{table_var}.len()"), + true, + )); + } + if cli.preserve_encodings { + config + .final_exprs + .push(format!("{}_encoding", config.var_name)); + if !key_encs.is_empty() { + config + .final_exprs + .push(format!("{}_key_encodings", config.var_name)); + } + if !value_encs.is_empty() { + config + .final_exprs + .push(format!("{}_value_encodings", config.var_name)); + } + deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + final_expr(config.final_exprs, Some(table_var)), + before_after.after_str(false) + )); + } else { + deser_code.content.line(&format!( + "{}{}{}", + before_after.before_str(false), + table_var, + before_after.after_str(false) + )); + } + } + deser_code.throws = true; + } + SerializingRustType::Root(ConceptualRustType::Alias(ident, ty), _cfg) => { + let config_for_alias = if let Some(custom_deserialize) = types + .type_aliases() + 
.get(ident) + .unwrap() + .rule_metadata + .as_ref() + .and_then(|rmd| rmd.custom_deserialize.clone()) + { + config.custom_deserialize(custom_deserialize) + } else { + config + }; + self.generate_deserialize( + types, + (&**ty).into(), + before_after, + config_for_alias, + cli, + ) .add_to_code(&mut deser_code); - } - SerializingRustType::EncodingOperation(CBOREncodingOperation::CBORBytes, child) => { - if cli.preserve_encodings { - config.final_exprs.push(format!( - "StringEncoding::from({}_bytes_encoding)", - config.var_name - )); + } + SerializingRustType::EncodingOperation(CBOREncodingOperation::CBORBytes, child) => { + if cli.preserve_encodings { + config.final_exprs.push(format!( + "StringEncoding::from({}_bytes_encoding)", + config.var_name + )); + deser_code.content.line(&format!( + "let ({}_bytes, {}_bytes_encoding) = raw.bytes_sz()?;", + config.var_name, config.var_name + )); + } else { + deser_code + .content + .line(&format!("let {}_bytes = raw.bytes()?;", config.var_name)); + }; + let name_overload = "inner_de"; deser_code.content.line(&format!( - "let ({}_bytes, {}_bytes_encoding) = raw.bytes_sz()?;", - config.var_name, config.var_name + "let {} = &mut Deserializer::from(std::io::Cursor::new({}_bytes));", + name_overload, config.var_name )); - } else { - deser_code - .content - .line(&format!("let {}_bytes = raw.bytes()?;", config.var_name)); - }; - let name_overload = "inner_de"; - deser_code.content.line(&format!( - "let {} = &mut Deserializer::from(std::io::Cursor::new({}_bytes));", - name_overload, config.var_name - )); - self.generate_deserialize( - types, - *child, - before_after, - config.overload_deserializer(name_overload), - cli, - ) - .add_to_code(&mut deser_code); - deser_code.throws = true; - } - SerializingRustType::EncodingOperation(CBOREncodingOperation::Tagged(tag), child) => { - if config.optional_field { - deser_code.content.line("read_len.read_elems(1)?;"); - deser_code.read_len_used = true; + self.generate_deserialize( + types, + 
*child, + before_after, + config.overload_deserializer(name_overload), + cli, + ) + .add_to_code(&mut deser_code); + deser_code.throws = true; } - let mut tag_check = if cli.preserve_encodings { - let mut tag_check = Block::new(format!( - "{}match {}.tag_sz()?", - before_after.before, deserializer_name - )); - config.final_exprs.push("Some(tag_enc)".to_owned()); - let some_deser_code = self - .generate_deserialize( - types, - *child, - DeserializeBeforeAfter::new("", "", before_after.expects_result), - config.optional_field(false), - cli, - ) - .mark_and_extract_content(&mut deser_code); - if let Some(single_line) = some_deser_code.as_single_line() { - tag_check.line(format!("({tag}, tag_enc) => {single_line},")); - } else { - let mut deser_block = Block::new(format!("({tag}, tag_enc) =>")); - deser_block.push_all(some_deser_code); - deser_block.after(","); - tag_check.push_block(deser_block); + SerializingRustType::EncodingOperation( + CBOREncodingOperation::Tagged(tag), + child, + ) => { + if config.optional_field { + deser_code.content.line("read_len.read_elems(1)?;"); + deser_code.read_len_used = true; } - tag_check - } else { - let mut tag_check = Block::new(format!( - "{}match {}.tag()?", - before_after.before, deserializer_name - )); - - let some_deser_code = self - .generate_deserialize( - types, - *child, - DeserializeBeforeAfter::new("", "", before_after.expects_result), - config.optional_field(false), - cli, - ) - .mark_and_extract_content(&mut deser_code); - if let Some(single_line) = some_deser_code.as_single_line() { - tag_check.line(format!("{tag} => {single_line},")); + let mut tag_check = if cli.preserve_encodings { + let mut tag_check = Block::new(format!( + "{}match {}.tag_sz()?", + before_after.before, deserializer_name + )); + config.final_exprs.push("Some(tag_enc)".to_owned()); + let some_deser_code = self + .generate_deserialize( + types, + *child, + DeserializeBeforeAfter::new("", "", before_after.expects_result), + 
config.optional_field(false), + cli, + ) + .mark_and_extract_content(&mut deser_code); + if let Some(single_line) = some_deser_code.as_single_line() { + tag_check.line(format!("({tag}, tag_enc) => {single_line},")); + } else { + let mut deser_block = Block::new(format!("({tag}, tag_enc) =>")); + deser_block.push_all(some_deser_code); + deser_block.after(","); + tag_check.push_block(deser_block); + } + tag_check } else { - let mut deser_block = Block::new(format!("{tag} =>")); - deser_block.push_all(some_deser_code); - deser_block.after(","); - tag_check.push_block(deser_block); - } - tag_check - }; - tag_check.line(&format!( + let mut tag_check = Block::new(format!( + "{}match {}.tag()?", + before_after.before, deserializer_name + )); + + let some_deser_code = self + .generate_deserialize( + types, + *child, + DeserializeBeforeAfter::new("", "", before_after.expects_result), + config.optional_field(false), + cli, + ) + .mark_and_extract_content(&mut deser_code); + if let Some(single_line) = some_deser_code.as_single_line() { + tag_check.line(format!("{tag} => {single_line},")); + } else { + let mut deser_block = Block::new(format!("{tag} =>")); + deser_block.push_all(some_deser_code); + deser_block.after(","); + tag_check.push_block(deser_block); + } + tag_check + }; + tag_check.line(&format!( "{} => {}Err(DeserializeFailure::TagMismatch{{ found: tag, expected: {} }}.into()),", if cli.preserve_encodings { "(tag, _enc)" } else { "tag" }, if before_after.expects_result { "" } else { "return " }, tag)); - tag_check.after(before_after.after); - deser_code.content.push_block(tag_check); - deser_code.throws = true; + tag_check.after(before_after.after); + deser_code.content.push_block(tag_check); + deser_code.throws = true; + } } } deser_code @@ -3140,6 +3326,7 @@ impl GenerationScope { name: &RustIdent, variants: &[EnumVariant], tag: Option, + config: &RustStructConfig, cli: &Cli, ) { // I don't believe this is even possible (wouldn't be a single CBOR value + nowhere to 
embed) @@ -3148,7 +3335,7 @@ impl GenerationScope { .iter() .all(|v| !matches!(v.data, EnumVariantData::Inlined(_)))); // Rust only - generate_enum(self, types, name, variants, None, true, tag, cli); + generate_enum(self, types, name, variants, None, true, tag, config, cli); if cli.wasm { // Generate a wrapper object that we will expose to wasm around this let mut wrapper = create_base_wasm_wrapper(self, types, name, true, cli); @@ -4075,22 +4262,6 @@ fn create_deserialize_impls( (deser_impl, deser_embedded_impl) } -fn push_rust_struct( - gen_scope: &mut GenerationScope, - types: &IntermediateTypes, - name: &RustIdent, - s: codegen::Struct, - s_impl: codegen::Impl, - ser_impl: codegen::Impl, - ser_embedded_impl: Option, -) { - gen_scope.rust(types, name).push_struct(s).push_impl(s_impl); - gen_scope.rust_serialize(types, name).push_impl(ser_impl); - if let Some(s) = ser_embedded_impl { - gen_scope.rust_serialize(types, name).push_impl(s); - } -} - // We need to execute field deserialization inside a closure in order to capture and annotate with the field name // without having to put error annotation inside of every single cbor_event call. fn make_err_annotate_block(annotation: &str, before: &str, after: &str) -> Block { @@ -4265,23 +4436,41 @@ struct EncodingField { /// this MUST be equivalent to the Default trait of the encoding field. /// This can be more concise though e.g. 
None for Option::default() default_expr: &'static str, + enc_conversion_before: &'static str, + enc_conversion_after: &'static str, + is_copy: bool, /// inner encodings - used for map/vec types #[allow(unused)] inner: Vec, } +impl EncodingField { + pub fn enc_conversion(&self, expr: &str) -> String { + format!( + "{}{}{}", + self.enc_conversion_before, expr, self.enc_conversion_after + ) + } +} + fn key_encoding_field(name: &str, key: &FixedValue) -> EncodingField { match key { FixedValue::Text(_) => EncodingField { field_name: format!("{name}_key_encoding"), type_name: "StringEncoding".to_owned(), default_expr: "StringEncoding::default()", + enc_conversion_before: "StringEncoding::from(", + enc_conversion_after: ")", + is_copy: false, inner: Vec::new(), }, FixedValue::Uint(_) => EncodingField { field_name: format!("{name}_key_encoding"), type_name: "Option".to_owned(), default_expr: "None", + enc_conversion_before: "Some(", + enc_conversion_after: ")", + is_copy: true, inner: Vec::new(), }, _ => unimplemented!(), @@ -4303,6 +4492,9 @@ fn encoding_fields( field_name: format!("{name}_default_present"), type_name: "bool".to_owned(), default_expr: "false", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: true, inner: Vec::new(), }); } @@ -4322,6 +4514,9 @@ fn encoding_fields_impl( field_name: format!("{name}_encoding"), type_name: "LenEncoding".to_owned(), default_expr: "LenEncoding::default()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: true, inner: Vec::new(), }; let inner_encs = @@ -4347,6 +4542,9 @@ fn encoding_fields_impl( field_name: format!("{name}_elem_encodings"), type_name: format!("Vec<{type_name_elem}>"), default_expr: "Vec::new()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: false, inner: inner_encs, }, ] @@ -4357,6 +4555,9 @@ fn encoding_fields_impl( field_name: format!("{name}_encoding"), type_name: "LenEncoding".to_owned(), default_expr: "LenEncoding::default()", + 
enc_conversion_before: "", + enc_conversion_after: "", + is_copy: true, inner: Vec::new(), }]; let key_encs = encoding_fields_impl(types, &format!("{name}_key"), (&**k).into(), cli); @@ -4384,6 +4585,9 @@ fn encoding_fields_impl( type_name_value ), default_expr: "BTreeMap::new()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: false, inner: key_encs, }); } @@ -4409,6 +4613,9 @@ fn encoding_fields_impl( type_name_value ), default_expr: "BTreeMap::new()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: false, inner: val_encs, }); } @@ -4419,6 +4626,9 @@ fn encoding_fields_impl( field_name: format!("{name}_encoding"), type_name: "StringEncoding".to_owned(), default_expr: "StringEncoding::default()", + enc_conversion_before: "StringEncoding::from(", + enc_conversion_after: ")", + is_copy: false, inner: Vec::new(), }], Primitive::I8 @@ -4435,6 +4645,9 @@ fn encoding_fields_impl( field_name: format!("{name}_encoding"), type_name: "Option".to_owned(), default_expr: "None", + enc_conversion_before: "Some(", + enc_conversion_after: ")", + is_copy: true, inner: Vec::new(), }], Primitive::Bool => @@ -4470,8 +4683,8 @@ fn encoding_fields_impl( cli, ), }, - SerializingRustType::Root(ConceptualRustType::Alias(_, _), _cfg) => { - panic!("resolve types before calling this") + SerializingRustType::Root(ConceptualRustType::Alias(_, ty), _cfg) => { + encoding_fields_impl(types, name, (&**ty).into(), cli) } SerializingRustType::Root(ConceptualRustType::Optional(ty), _cfg) => { encoding_fields(types, name, ty, false, cli) @@ -4607,6 +4820,9 @@ fn generate_array_struct_serialization( }; let mut optional_array_ser_block = Block::new(optional_field_check); let mut config = SerializeConfig::new(field_expr, &field.name).expr_is_ref(expr_is_ref); + if let Some(custom_serialize) = &field.rule_metadata.custom_serialize { + config = config.custom_serialize(custom_serialize.clone()); + } if vars_in_self { config = 
config.encoding_var_in_option_struct("self.encodings") } else { @@ -4622,6 +4838,9 @@ fn generate_array_struct_serialization( ser_func.push_block(optional_array_ser_block); } else { let mut config = SerializeConfig::new(&field_expr, &field.name); + if let Some(custom_serialize) = &field.rule_metadata.custom_serialize { + config = config.custom_serialize(custom_serialize.clone()); + } if vars_in_self { config = config.encoding_var_in_option_struct("self.encodings") } else { @@ -4793,14 +5012,18 @@ fn generate_array_struct_deserialization( } else { (Cow::from("Some"), Cow::from("None")) }; + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(true); + if let Some(custom_deserialize) = &field.rule_metadata.custom_deserialize { + deser_config = deser_config.custom_deserialize(custom_deserialize.clone()); + } gen_scope .generate_deserialize( types, (&field.rust_type).into(), DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(true), + deser_config, cli, ) .annotate(&field.name, "", &format!(".map({some_map})")) @@ -4808,14 +5031,18 @@ fn generate_array_struct_deserialization( .add_to_code(&mut deser_code); type_check_else.line(format!("Ok({defaults})")); } else { + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(true); + if let Some(custom_deserialize) = &field.rule_metadata.custom_deserialize { + deser_config = deser_config.custom_deserialize(custom_deserialize.clone()); + } gen_scope .generate_deserialize( types, (&field.rust_type).into(), DeserializeBeforeAfter::new("Some(", ")", false), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(true), + deser_config, cli, ) .wrap_in_block(type_check_block) @@ -4827,23 +5054,31 @@ fn generate_array_struct_deserialization( } else { // mandatory fields if cli.annotate_fields { + let mut deser_config = 
DeserializeConfig::new(&field.name).in_embedded(in_embedded); + if let Some(custom_deserialize) = &field.rule_metadata.custom_deserialize { + deser_config = deser_config.custom_deserialize(custom_deserialize.clone()); + } gen_scope .generate_deserialize( types, (&field.rust_type).into(), DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), + deser_config, cli, ) .annotate(&field.name, before.as_ref(), after.as_ref()) .add_to_code(&mut deser_code); } else { + let mut deser_config = DeserializeConfig::new(&field.name).in_embedded(in_embedded); + if let Some(custom_deserialize) = &field.rule_metadata.custom_deserialize { + deser_config = deser_config.custom_deserialize(custom_deserialize.clone()); + } gen_scope .generate_deserialize( types, (&field.rust_type).into(), DeserializeBeforeAfter::new(before.as_ref(), after.as_ref(), false), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), + deser_config, cli, ) .add_to_code(&mut deser_code); @@ -4892,6 +5127,7 @@ fn codegen_struct( name: &RustIdent, tag: Option, record: &RustRecord, + config: &RustStructConfig, cli: &Cli, ) { let new_can_fail = record @@ -5164,675 +5400,745 @@ fn codegen_struct( native_impl.push_fn(native_new); // Serialization (via rust traits) - includes Deserialization too - let (ser_func, mut ser_impl, mut ser_embedded_impl) = create_serialize_impls( - name, - Some(record.rep), - tag, - &record.definite_info("self", false, types, cli), - len_encoding_var - .map(|var| { - format!("self.encodings.as_ref().map(|encs| encs.{var}).unwrap_or_default()") - }) - .as_deref(), - types.is_plain_group(name), - cli, - ); - let mut ser_func = match ser_embedded_impl { - Some(_) => { - ser_impl.push_fn(ser_func); - make_serialization_function("serialize_as_embedded_group", cli) - } - None => ser_func, - }; - let mut deser_code = DeserializationCode::default(); - let in_embedded = types.is_plain_group(name); - let ctor_block = match record.rep { - 
Representation::Array => { - generate_array_struct_serialization(gen_scope, types, record, true, &mut ser_func, cli); - let code = generate_array_struct_deserialization( - gen_scope, - types, - name, - record, - tag, - in_embedded, - true, - cli, - ); - code.deser_code.add_to_code(&mut deser_code); - let mut deser_ctor = Block::new(format!("Ok({name}")); - for (var, expr) in code.deser_ctor_fields { - if var == expr { - deser_ctor.line(format!("{var},")); - } else { - deser_ctor.line(format!("{var}: {expr},")); - } + if config.custom_serialize.is_none() || config.custom_deserialize.is_none() { + let (ser_func, mut ser_impl, mut ser_embedded_impl) = create_serialize_impls( + name, + Some(record.rep), + tag, + &record.definite_info("self", false, types, cli), + len_encoding_var + .map(|var| { + format!("self.encodings.as_ref().map(|encs| encs.{var}).unwrap_or_default()") + }) + .as_deref(), + types.is_plain_group(name), + cli, + ); + let mut ser_func = match ser_embedded_impl { + Some(_) => { + ser_impl.push_fn(ser_func); + make_serialization_function("serialize_as_embedded_group", cli) } - if !code.encoding_struct_ctor_fields.is_empty() { - let mut encoding_ctor_block = Block::new(format!("encodings: Some({name}Encoding")); - encoding_ctor_block.after("),"); - for (var, expr) in code.encoding_struct_ctor_fields { + None => ser_func, + }; + let mut deser_code = DeserializationCode::default(); + let in_embedded = types.is_plain_group(name); + let ctor_block = match record.rep { + Representation::Array => { + generate_array_struct_serialization( + gen_scope, + types, + record, + true, + &mut ser_func, + cli, + ); + let code = generate_array_struct_deserialization( + gen_scope, + types, + name, + record, + tag, + in_embedded, + true, + cli, + ); + code.deser_code.add_to_code(&mut deser_code); + let mut deser_ctor = Block::new(format!("Ok({name}")); + for (var, expr) in code.deser_ctor_fields { if var == expr { - encoding_ctor_block.line(format!("{var},")); + 
deser_ctor.line(format!("{var},")); } else { - encoding_ctor_block.line(format!("{var}: {expr},")); + deser_ctor.line(format!("{var}: {expr},")); } } - deser_ctor.push_block(encoding_ctor_block); - } - deser_ctor.after(")"); - deser_ctor - } - Representation::Map => { - let mut uint_field_deserializers = Vec::new(); - let mut text_field_deserializers = Vec::new(); - // (field_index, field, content) -- this is ordered by canonical order - let mut ser_content: Vec<(usize, &RustField, BlocksOrLines)> = Vec::new(); - if cli.preserve_encodings { - deser_code - .content - .line("let mut orig_deser_order = Vec::new();"); - } - // we default to canonical ordering here as the default ordering as that should be the most useful - // keep in mind this is always overwritten if you have cli.preserve_encodings enabled AND there was - // a deserialized encoding, otherwise we still use this by default. - for (field_index, field) in record.canonical_ordering() { - // to support maps with plain groups inside is very difficult as we cannot guarantee - // the order of fields so foo = {a, b, bar}, bar = (c, d) could have the order be - // {a, d, c, b}, {c, a, b, d}, etc which doesn't fit with the nature of deserialize_as_embedded_group - // A possible solution would be to take all fields into one big map, either in generation to begin with, - // or just for deserialization then constructing at the end with locals like a, b, bar_c, bar_d. 
- if let ConceptualRustType::Rust(ident) = &field.rust_type.conceptual_type { - if types.is_plain_group(ident) { - gen_scope.dont_generate_deserialize( - name, - format!( - "Map with plain group field {}: {}", - field.name, - field.rust_type.for_rust_member(types, false, cli) - ), - ); + if !code.encoding_struct_ctor_fields.is_empty() { + let mut encoding_ctor_block = + Block::new(format!("encodings: Some({name}Encoding")); + encoding_ctor_block.after("),"); + for (var, expr) in code.encoding_struct_ctor_fields { + if var == expr { + encoding_ctor_block.line(format!("{var},")); + } else { + encoding_ctor_block.line(format!("{var}: {expr},")); + } } + deser_ctor.push_block(encoding_ctor_block); } - // declare variables for deser loop + deser_ctor.after(")"); + deser_ctor + } + Representation::Map => { + let mut uint_field_deserializers = Vec::new(); + let mut text_field_deserializers = Vec::new(); + // (field_index, field, content) -- this is ordered by canonical order + let mut ser_content: Vec<(usize, &RustField, BlocksOrLines)> = Vec::new(); if cli.preserve_encodings { - for field_enc in encoding_fields( - types, - &field.name, - &field.rust_type.clone().resolve_aliases(), - true, - cli, - ) { + deser_code + .content + .line("let mut orig_deser_order = Vec::new();"); + } + // we default to canonical ordering here as the default ordering as that should be the most useful + // keep in mind this is always overwritten if you have cli.preserve_encodings enabled AND there was + // a deserialized encoding, otherwise we still use this by default. 
+ for (field_index, field) in record.canonical_ordering() { + // to support maps with plain groups inside is very difficult as we cannot guarantee + // the order of fields so foo = {a, b, bar}, bar = (c, d) could have the order be + // {a, d, c, b}, {c, a, b, d}, etc which doesn't fit with the nature of deserialize_as_embedded_group + // A possible solution would be to take all fields into one big map, either in generation to begin with, + // or just for deserialization then constructing at the end with locals like a, b, bar_c, bar_d. + if let ConceptualRustType::Rust(ident) = &field.rust_type.conceptual_type { + if types.is_plain_group(ident) { + gen_scope.dont_generate_deserialize( + name, + format!( + "Map with plain group field {}: {}", + field.name, + field.rust_type.for_rust_member(types, false, cli) + ), + ); + } + } + // declare variables for deser loop + if cli.preserve_encodings { + for field_enc in encoding_fields( + types, + &field.name, + &field.rust_type.clone().resolve_aliases(), + true, + cli, + ) { + deser_code.content.line(&format!( + "let mut {} = {};", + field_enc.field_name, field_enc.default_expr + )); + } + let key_enc = key_encoding_field(&field.name, field.key.as_ref().unwrap()); deser_code.content.line(&format!( "let mut {} = {};", - field_enc.field_name, field_enc.default_expr + key_enc.field_name, key_enc.default_expr )); } - let key_enc = key_encoding_field(&field.name, field.key.as_ref().unwrap()); - deser_code.content.line(&format!( - "let mut {} = {};", - key_enc.field_name, key_enc.default_expr - )); - } - if field.rust_type.is_fixed_value() { - deser_code - .content - .line(&format!("let mut {}_present = false;", field.name)); - } else { - deser_code - .content - .line(&format!("let mut {} = None;", field.name)); - } - let (data_name, expr_is_ref) = - if field.optional && field.rust_type.config.default.is_none() { - (String::from("field"), true) + if field.rust_type.is_fixed_value() { + deser_code + .content + .line(&format!("let 
mut {}_present = false;", field.name)); } else { - (format!("self.{}", field.name), false) - }; - - let key = field.key.clone().unwrap(); - // deserialize key + value - let mut deser_block = match &key { - FixedValue::Uint(x) => { - if cli.preserve_encodings { - Block::new(format!("({x}, key_enc) => ")) + deser_code + .content + .line(&format!("let mut {} = None;", field.name)); + } + let (data_name, expr_is_ref) = + if field.optional && field.rust_type.config.default.is_none() { + (String::from("field"), true) } else { - Block::new(format!("{x} => ")) + (format!("self.{}", field.name), false) + }; + + let key = field.key.clone().unwrap(); + // deserialize key + value + let mut deser_block = match &key { + FixedValue::Uint(x) => { + if cli.preserve_encodings { + Block::new(format!("({x}, key_enc) => ")) + } else { + Block::new(format!("{x} => ")) + } } - } - FixedValue::Text(x) => Block::new(format!("\"{x}\" => ")), - _ => panic!( - "unsupported map key type for {}.{}: {:?}", - name, field.name, key - ), - }; - deser_block.after(","); - let mut deser_block_code = DeserializationCode::default(); - let key_in_rust = match &key { - FixedValue::Uint(x) => format!("Key::Uint({x})"), - FixedValue::Text(x) => format!("Key::Str(\"{x}\".into())"), - _ => unimplemented!(), - }; - if cli.preserve_encodings { - let mut dup_check = if field.rust_type.is_fixed_value() { - Block::new(format!("if {}_present", field.name)) - } else { - Block::new(format!("if {}.is_some()", field.name)) + FixedValue::Text(x) => Block::new(format!("\"{x}\" => ")), + _ => panic!( + "unsupported map key type for {}.{}: {:?}", + name, field.name, key + ), }; - dup_check.line(&format!( - "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" - )); - deser_block_code.content.push_block(dup_check); - - let temp_var_prefix = format!("tmp_{}", field.name); - let var_names_str = - encoding_var_names_str(types, &temp_var_prefix, &field.rust_type, cli); - if cli.annotate_fields { - let (before, 
after) = if var_names_str.is_empty() { - ("".to_owned(), "?") + deser_block.after(","); + let mut deser_block_code = DeserializationCode::default(); + let key_in_rust = match &key { + FixedValue::Uint(x) => format!("Key::Uint({x})"), + FixedValue::Text(x) => format!("Key::Str(\"{x}\".into())"), + _ => unimplemented!(), + }; + if cli.preserve_encodings { + let mut dup_check = if field.rust_type.is_fixed_value() { + Block::new(format!("if {}_present", field.name)) } else { - (format!("let {var_names_str} = "), "?;") + Block::new(format!("if {}.is_some()", field.name)) }; - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(field.optional), - cli, - ) - .annotate(&field.name, &before, after) - .add_to_code(&mut deser_block_code); - } else { - let (before, after) = if var_names_str.is_empty() { - ("".to_owned(), "") + dup_check.line(&format!( + "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" + )); + deser_block_code.content.push_block(dup_check); + + let temp_var_prefix = format!("tmp_{}", field.name); + let var_names_str = + encoding_var_names_str(types, &temp_var_prefix, &field.rust_type, cli); + if cli.annotate_fields { + let (before, after) = if var_names_str.is_empty() { + ("".to_owned(), "?") + } else { + (format!("let {var_names_str} = "), "?;") + }; + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + deser_config.custom_deserialize(custom_deserialize.clone()); + } + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new("", "", true), + deser_config, + cli, + ) + .annotate(&field.name, &before, after) + .add_to_code(&mut deser_block_code); } else { - (format!("let 
{var_names_str} = "), ";") - }; - gen_scope - .generate_deserialize( + let (before, after) = if var_names_str.is_empty() { + ("".to_owned(), "") + } else { + (format!("let {var_names_str} = "), ";") + }; + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + deser_config.custom_deserialize(custom_deserialize.clone()); + } + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new(&before, after, false), + deser_config, + cli, + ) + .add_to_code(&mut deser_block_code); + } + // Due to destructuring assignemnt (RFC 372 / 71156) being unstable we're forced to use temporaries then reassign after + // which is not ideal but doing the assignment inside the lambda or otherwise has issues where it's putting lots of + // context-sensitive logic into generate_deserialize and you would need to declare temporaries in most cases anyway + // as cbor_event encoding-aware functions return tuples which just pushes the problem there instead. 
+ // We might be able to write a nice way around this in the annotate_fields=false, preserve_encodings=true case + // but I don't think anyone (or many) would care about this as it's incredibly niche + // (annotate_fields=false would be for minimizing code size but then preserve_encodings=true generates way more code) + if field.rust_type.is_fixed_value() { + deser_block_code + .content + .line(&format!("{}_present = true;", field.name)); + } else { + deser_block_code + .content + .line(&format!("{} = Some(tmp_{});", field.name, field.name)); + } + for enc_field in encoding_fields( + types, + &field.name, + &field.rust_type.clone().resolve_aliases(), + false, + cli, + ) { + deser_block_code.content.line(&format!( + "{} = tmp_{};", + enc_field.field_name, enc_field.field_name + )); + } + } else if field.rust_type.is_fixed_value() { + let mut dup_check = Block::new(format!("if {}_present", field.name)); + dup_check.line(&format!( + "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" + )); + deser_block_code.content.push_block(dup_check); + // only does verification and sets the field_present bool to do error checking later + if cli.annotate_fields { + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + deser_config.custom_deserialize(custom_deserialize.clone()); + } + let mut err_deser = gen_scope.generate_deserialize( types, (&field.rust_type).into(), - DeserializeBeforeAfter::new(&before, after, false), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(field.optional), + DeserializeBeforeAfter::new("", "", false), + deser_config, cli, - ) - .add_to_code(&mut deser_block_code); - } - // Due to destructuring assignemnt (RFC 372 / 71156) being unstable we're forced to use temporaries then reassign after - // which is not ideal but doing the 
assignment inside the lambda or otherwise has issues where it's putting lots of - // context-sensitive logic into generate_deserialize and you would need to declare temporaries in most cases anyway - // as cbor_event encoding-aware functions return tuples which just pushes the problem there instead. - // We might be able to write a nice way around this in the annotate_fields=false, preserve_encodings=true case - // but I don't think anyone (or many) would care about this as it's incredibly niche - // (annotate_fields=false would be for minimizing code size but then preserve_encodings=true generates way more code) - if field.rust_type.is_fixed_value() { - deser_block_code - .content - .line(&format!("{}_present = true;", field.name)); - } else { - deser_block_code - .content - .line(&format!("{} = Some(tmp_{});", field.name, field.name)); - } - for enc_field in encoding_fields( - types, - &field.name, - &field.rust_type.clone().resolve_aliases(), - false, - cli, - ) { - deser_block_code.content.line(&format!( - "{} = tmp_{};", - enc_field.field_name, enc_field.field_name - )); - } - } else if field.rust_type.is_fixed_value() { - let mut dup_check = Block::new(format!("if {}_present", field.name)); - dup_check.line(&format!( - "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" - )); - deser_block_code.content.push_block(dup_check); - // only does verification and sets the field_present bool to do error checking later - if cli.annotate_fields { - let mut err_deser = gen_scope.generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", false), - DeserializeConfig::new(&field.name) + ); + err_deser.content.line("Ok(true)"); + err_deser + .annotate(&field.name, &format!("{}_present = ", field.name), "?;") + .add_to_code(&mut deser_block_code); + } else { + let mut deser_config = DeserializeConfig::new(&field.name) .in_embedded(in_embedded) - .optional_field(field.optional), - cli, - ); - 
err_deser.content.line("Ok(true)"); - err_deser - .annotate(&field.name, &format!("{}_present = ", field.name), "?;") - .add_to_code(&mut deser_block_code); + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + deser_config.custom_deserialize(custom_deserialize.clone()); + } + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new("", "", false), + deser_config, + cli, + ) + .add_to_code(&mut deser_block_code); + deser_block_code + .content + .line(&format!("{}_present = true;", field.name)); + } } else { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", false), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(field.optional), - cli, - ) - .add_to_code(&mut deser_block_code); + let mut dup_check = Block::new(format!("if {}.is_some()", field.name)); + dup_check.line(&format!( + "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" + )); + deser_block_code.content.push_block(dup_check); + if cli.annotate_fields { + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + deser_config.custom_deserialize(custom_deserialize.clone()); + } + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new("", "", true), + deser_config, + cli, + ) + .annotate(&field.name, &format!("{} = Some(", field.name), "?);") + .add_to_code(&mut deser_block_code); + } else { + let mut deser_config = DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(field.optional); + if let Some(custom_deserialize) = + &field.rule_metadata.custom_deserialize + { + deser_config = + 
deser_config.custom_deserialize(custom_deserialize.clone()); + } + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new( + &format!("{} = Some(", field.name), + ");", + false, + ), + deser_config, + cli, + ) + .add_to_code(&mut deser_block_code); + } + } + if cli.preserve_encodings { + let key_encoding = key_encoding_field(&field.name, &key); deser_block_code .content - .line(&format!("{}_present = true;", field.name)); + .line(&format!( + "{} = {};", + key_encoding.field_name, + key_encoding.enc_conversion("key_enc") + )) + .line(&format!("orig_deser_order.push({field_index});")); } - } else { - let mut dup_check = Block::new(format!("if {}.is_some()", field.name)); - dup_check.line(&format!( - "return Err(DeserializeFailure::DuplicateKey({key_in_rust}).into());" - )); - deser_block_code.content.push_block(dup_check); - if cli.annotate_fields { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(field.optional), + + // serialize key + let mut map_ser_content = BlocksOrLines::default(); + let serialize_config = SerializeConfig::new(&data_name, &field.name) + .expr_is_ref(expr_is_ref) + .encoding_var_in_option_struct("self.encodings"); + let key_encoding_var = + serialize_config.encoding_var(Some("key"), key.encoding_var_is_copy(types)); + + deser_block + .push_all(deser_block_code.mark_and_extract_content(&mut deser_code)); + match &key { + FixedValue::Uint(x) => { + let expr = format!("{x}u64"); + write_using_sz( + &mut map_ser_content, + "write_unsigned_integer", + "serializer", + &expr, + &expr, + "?;", + &key_encoding_var, cli, - ) - .annotate(&field.name, &format!("{} = Some(", field.name), "?);") - .add_to_code(&mut deser_block_code); - } else { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new( - 
&format!("{} = Some(", field.name), - ");", - false, - ), - DeserializeConfig::new(&field.name) - .in_embedded(in_embedded) - .optional_field(field.optional), + ); + uint_field_deserializers.push(deser_block); + } + FixedValue::Text(s) => { + write_string_sz( + &mut map_ser_content, + "write_text", + "serializer", + &format!("\"{s}\""), + false, + "?;", + &key_encoding_var, cli, - ) - .add_to_code(&mut deser_block_code); - } + ); + text_field_deserializers.push(deser_block); + } + _ => panic!( + "unsupported map key type for {}.{}: {:?}", + name, field.name, key + ), + }; + + // serialize value + gen_scope.generate_serialize( + types, + (&field.rust_type).into(), + &mut map_ser_content, + serialize_config, + cli, + ); + ser_content.push((field_index, field, map_ser_content)); } if cli.preserve_encodings { - let key_encoding_var = key_encoding_field(&field.name, &key).field_name; - let enc_conversion = match &key { - FixedValue::Uint(_) => "Some(key_enc)", - FixedValue::Text(_) => "StringEncoding::from(key_enc)", - _ => unimplemented!(), + let (check_canonical, serialization_order) = if cli.canonical_form { + let indices_str = record + .canonical_ordering() + .iter() + .map(|(i, _)| i.to_string()) + .collect::>() + .join(","); + ("!force_canonical && ", format!("vec![{indices_str}]")) + } else { + ("", format!("(0..{}).collect()", ser_content.len())) }; - deser_block_code - .content - .line(&format!("{key_encoding_var} = {enc_conversion};")) - .line(&format!("orig_deser_order.push({field_index});")); - } - - // serialize key - let mut map_ser_content = BlocksOrLines::default(); - let serialize_config = SerializeConfig::new(&data_name, &field.name) - .expr_is_ref(expr_is_ref) - .encoding_var_in_option_struct("self.encodings"); - let key_encoding_var = - serialize_config.encoding_var(Some("key"), key.encoding_var_is_copy(types)); - - deser_block.push_all(deser_block_code.mark_and_extract_content(&mut deser_code)); - match &key { - FixedValue::Uint(x) => { - let expr = 
format!("{x}u64"); - write_using_sz( - &mut map_ser_content, - "write_unsigned_integer", - "serializer", - &expr, - &expr, - "?;", - &key_encoding_var, - cli, - ); - uint_field_deserializers.push(deser_block); - } - FixedValue::Text(s) => { - write_string_sz( - &mut map_ser_content, - "write_text", - "serializer", - &format!("\"{s}\""), - false, - "?;", - &key_encoding_var, - cli, - ); - text_field_deserializers.push(deser_block); - } - _ => panic!( - "unsupported map key type for {}.{}: {:?}", - name, field.name, key - ), - }; - - // serialize value - gen_scope.generate_serialize( - types, - (&field.rust_type).into(), - &mut map_ser_content, - serialize_config, - cli, - ); - ser_content.push((field_index, field, map_ser_content)); - } - if cli.preserve_encodings { - let (check_canonical, serialization_order) = if cli.canonical_form { - let indices_str = record - .canonical_ordering() - .iter() - .map(|(i, _)| i.to_string()) - .collect::>() - .join(","); - ("!force_canonical && ", format!("vec![{indices_str}]")) - } else { - ("", format!("(0..{}).collect()", ser_content.len())) - }; - ser_func.line(format!( + ser_func.line(format!( "let deser_order = self.encodings.as_ref().filter(|encs| {}encs.orig_deser_order.len() == {}).map(|encs| encs.orig_deser_order.clone()).unwrap_or_else(|| {});", check_canonical, record.definite_info("self", false, types, cli), serialization_order)); - let mut ser_loop = Block::new("for field_index in deser_order"); - let mut ser_loop_match = Block::new("match field_index"); - for (field_index, field, content) in ser_content.into_iter() { - // TODO: while this would be nice we would need to either: - // 1) know this before we call gen_scope.generate_serialize() OR - // 2) strip that !is_end (?;) field from it which seems brittle - //if let Some(single_line) = content.as_single_line() { - // ser_loop_match.line(format!("{} => {},")); - //} else { - //} - let mut field_ser_block = - if field.optional && 
field.rust_type.config.default.is_none() { - Block::new(format!( - "{} => if let Some(field) = &self.{}", - field_index, field.name - )) - } else { - Block::new(format!("{field_index} =>")) - }; - field_ser_block.push_all(content); - ser_loop_match.push_block(field_ser_block); - } - ser_loop_match.line("_ => unreachable!()").after(";"); - ser_loop.push_block(ser_loop_match); - ser_func.push_block(ser_loop); - } else { - for (_field_index, field, content) in ser_content.into_iter() { - if field.optional { - let optional_ser_field_check = - if let Some(default_value) = &field.rust_type.config.default { - format!( - "if self.{} != {}", - field.name, - default_value.to_primitive_str_compare() - ) + let mut ser_loop = Block::new("for field_index in deser_order"); + let mut ser_loop_match = Block::new("match field_index"); + for (field_index, field, content) in ser_content.into_iter() { + // TODO: while this would be nice we would need to either: + // 1) know this before we call gen_scope.generate_serialize() OR + // 2) strip that !is_end (?;) field from it which seems brittle + //if let Some(single_line) = content.as_single_line() { + // ser_loop_match.line(format!("{} => {},")); + //} else { + //} + let mut field_ser_block = + if field.optional && field.rust_type.config.default.is_none() { + Block::new(format!( + "{} => if let Some(field) = &self.{}", + field_index, field.name + )) } else { - format!("if let Some(field) = &self.{}", field.name) + Block::new(format!("{field_index} =>")) }; - let mut optional_ser_field = Block::new(optional_ser_field_check); - optional_ser_field.push_all(content); - ser_func.push_block(optional_ser_field); - } else { - ser_func.push_all(content); + field_ser_block.push_all(content); + ser_loop_match.push_block(field_ser_block); } - } - } - // needs to be in one line rather than a block because Block::after() only takes a string - deser_code.content.line("let mut read = 0;"); - let mut deser_loop = make_deser_loop("len", "read", cli); - 
let mut type_match = Block::new("match raw.cbor_type()?"); - if uint_field_deserializers.is_empty() { - type_match.line("cbor_event::Type::UnsignedInteger => return Err(DeserializeFailure::UnknownKey(Key::Uint(raw.unsigned_integer()?)).into()),"); - } else { - let mut uint_match = if cli.preserve_encodings { - Block::new( - "cbor_event::Type::UnsignedInteger => match raw.unsigned_integer_sz()?", - ) + ser_loop_match.line("_ => unreachable!()").after(";"); + ser_loop.push_block(ser_loop_match); + ser_func.push_block(ser_loop); } else { - Block::new("cbor_event::Type::UnsignedInteger => match raw.unsigned_integer()?") - }; - for case in uint_field_deserializers { - uint_match.push_block(case); + for (_field_index, field, content) in ser_content.into_iter() { + if field.optional { + let optional_ser_field_check = + if let Some(default_value) = &field.rust_type.config.default { + format!( + "if self.{} != {}", + field.name, + default_value.to_primitive_str_compare() + ) + } else { + format!("if let Some(field) = &self.{}", field.name) + }; + let mut optional_ser_field = Block::new(optional_ser_field_check); + optional_ser_field.push_all(content); + ser_func.push_block(optional_ser_field); + } else { + ser_func.push_all(content); + } + } } - let unknown_key_decl = if cli.preserve_encodings { - "(unknown_key, _enc)" + // needs to be in one line rather than a block because Block::after() only takes a string + deser_code.content.line("let mut read = 0;"); + let mut deser_loop = make_deser_loop("len", "read", cli); + let mut type_match = Block::new("match raw.cbor_type()?"); + if uint_field_deserializers.is_empty() { + type_match.line("cbor_event::Type::UnsignedInteger => return Err(DeserializeFailure::UnknownKey(Key::Uint(raw.unsigned_integer()?)).into()),"); } else { - "unknown_key" - }; - uint_match.line(format!("{unknown_key_decl} => return Err(DeserializeFailure::UnknownKey(Key::Uint(unknown_key)).into()),")); - uint_match.after(","); - 
type_match.push_block(uint_match); - } - // we can't map text_sz() with String::as_str() to match it since that would return a reference to a temporary - // so we need to store it in a local and have an extra block to declare it - if text_field_deserializers.is_empty() { - type_match.line("cbor_event::Type::Text => return Err(DeserializeFailure::UnknownKey(Key::Str(raw.text()?)).into()),"); - } else if cli.preserve_encodings { - let mut outer_match = Block::new("cbor_event::Type::Text =>"); - outer_match.line("let (text_key, key_enc) = raw.text_sz()?;"); - let mut text_match = Block::new("match text_key.as_str()"); - for case in text_field_deserializers { - text_match.push_block(case); + let mut uint_match = if cli.preserve_encodings { + Block::new( + "cbor_event::Type::UnsignedInteger => match raw.unsigned_integer_sz()?", + ) + } else { + Block::new( + "cbor_event::Type::UnsignedInteger => match raw.unsigned_integer()?", + ) + }; + for case in uint_field_deserializers { + uint_match.push_block(case); + } + let unknown_key_decl = if cli.preserve_encodings { + "(unknown_key, _enc)" + } else { + "unknown_key" + }; + uint_match.line(format!("{unknown_key_decl} => return Err(DeserializeFailure::UnknownKey(Key::Uint(unknown_key)).into()),")); + uint_match.after(","); + type_match.push_block(uint_match); } - text_match.line("unknown_key => return Err(DeserializeFailure::UnknownKey(Key::Str(unknown_key.to_owned())).into()),"); - outer_match.after(","); - outer_match.push_block(text_match); - type_match.push_block(outer_match); - } else { - let mut text_match = - Block::new("cbor_event::Type::Text => match raw.text()?.as_str()"); - for case in text_field_deserializers { - text_match.push_block(case); + // we can't map text_sz() with String::as_str() to match it since that would return a reference to a temporary + // so we need to store it in a local and have an extra block to declare it + if text_field_deserializers.is_empty() { + type_match.line("cbor_event::Type::Text => 
return Err(DeserializeFailure::UnknownKey(Key::Str(raw.text()?)).into()),"); + } else if cli.preserve_encodings { + let mut outer_match = Block::new("cbor_event::Type::Text =>"); + outer_match.line("let (text_key, key_enc) = raw.text_sz()?;"); + let mut text_match = Block::new("match text_key.as_str()"); + for case in text_field_deserializers { + text_match.push_block(case); + } + text_match.line("unknown_key => return Err(DeserializeFailure::UnknownKey(Key::Str(unknown_key.to_owned())).into()),"); + outer_match.after(","); + outer_match.push_block(text_match); + type_match.push_block(outer_match); + } else { + let mut text_match = + Block::new("cbor_event::Type::Text => match raw.text()?.as_str()"); + for case in text_field_deserializers { + text_match.push_block(case); + } + text_match.line("unknown_key => return Err(DeserializeFailure::UnknownKey(Key::Str(unknown_key.to_owned())).into()),"); + text_match.after(","); + type_match.push_block(text_match); } - text_match.line("unknown_key => return Err(DeserializeFailure::UnknownKey(Key::Str(unknown_key.to_owned())).into()),"); - text_match.after(","); - type_match.push_block(text_match); - } - let mut special_match = Block::new("cbor_event::Type::Special => match len"); - special_match.line(format!( - "{} => return Err(DeserializeFailure::BreakInDefiniteLen.into()),", - cbor_event_len_n("_", cli) - )); - // TODO: this will need to change if we support Special values as keys (e.g. 
true / false) - let mut break_check = Block::new(format!( - "{} => match raw.special()?", - cbor_event_len_indef(cli) - )); - break_check.line("cbor_event::Special::Break => break,"); - break_check.line("_ => return Err(DeserializeFailure::EndingBreakMissing.into()),"); - break_check.after(","); - special_match.push_block(break_check); - special_match.after(","); - type_match.push_block(special_match); - type_match.line("other_type => return Err(DeserializeFailure::UnexpectedKeyType(other_type).into()),"); - deser_loop.push_block(type_match); - deser_loop.line("read += 1;"); - deser_code.content.push_block(deser_loop); - let mut ctor_block = Block::new("Ok(Self"); - // make sure the field is present, and unwrap the Option - for field in &record.fields { - if !field.optional { - let key = match &field.key { - Some(FixedValue::Uint(x)) => format!("Key::Uint({x})"), - Some(FixedValue::Text(x)) => format!("Key::Str(String::from(\"{x}\"))"), - None => unreachable!(), - _ => unimplemented!(), - }; - if field.rust_type.is_fixed_value() { - let mut mandatory_field_check = - Block::new(format!("if !{}_present", field.name)); - mandatory_field_check.line(format!( + let mut special_match = Block::new("cbor_event::Type::Special => match len"); + special_match.line(format!( + "{} => return Err(DeserializeFailure::BreakInDefiniteLen.into()),", + cbor_event_len_n("_", cli) + )); + // TODO: this will need to change if we support Special values as keys (e.g. 
true / false) + let mut break_check = Block::new(format!( + "{} => match raw.special()?", + cbor_event_len_indef(cli) + )); + break_check.line("cbor_event::Special::Break => break,"); + break_check.line("_ => return Err(DeserializeFailure::EndingBreakMissing.into()),"); + break_check.after(","); + special_match.push_block(break_check); + special_match.after(","); + type_match.push_block(special_match); + type_match.line("other_type => return Err(DeserializeFailure::UnexpectedKeyType(other_type).into()),"); + deser_loop.push_block(type_match); + deser_loop.line("read += 1;"); + deser_code.content.push_block(deser_loop); + let mut ctor_block = Block::new("Ok(Self"); + // make sure the field is present, and unwrap the Option + for field in &record.fields { + if !field.optional { + let key = match &field.key { + Some(FixedValue::Uint(x)) => format!("Key::Uint({x})"), + Some(FixedValue::Text(x)) => format!("Key::Str(String::from(\"{x}\"))"), + None => unreachable!(), + _ => unimplemented!(), + }; + if field.rust_type.is_fixed_value() { + let mut mandatory_field_check = + Block::new(format!("if !{}_present", field.name)); + mandatory_field_check.line(format!( "return Err(DeserializeFailure::MandatoryFieldMissing({key}).into());" )); - deser_code.content.push_block(mandatory_field_check); - } else { - let mut mandatory_field_check = - Block::new(format!("let {} = match {}", field.name, field.name)); - mandatory_field_check.line("Some(x) => x,"); + deser_code.content.push_block(mandatory_field_check); + } else { + let mut mandatory_field_check = + Block::new(format!("let {} = match {}", field.name, field.name)); + mandatory_field_check.line("Some(x) => x,"); - mandatory_field_check.line(format!("None => return Err(DeserializeFailure::MandatoryFieldMissing({key}).into()),")); - mandatory_field_check.after(";"); - deser_code.content.push_block(mandatory_field_check); - } - } else if let Some(default_value) = &field.rust_type.config.default { - if cli.preserve_encodings { - 
let mut default_present_check = Block::new(format!( - "if {} == Some({})", - field.name, - default_value.to_primitive_str_assign() - )); - default_present_check - .line(format!("{}_default_present = true;", field.name)); - deser_code.content.push_block(default_present_check); - } - match default_value { - FixedValue::Text(_) => { - // to avoid clippy::or_fun_call - deser_code.content.line(&format!( - "let {} = {}.unwrap_or_else(|| {});", - field.name, - field.name, - default_value.to_primitive_str_assign() - )); + mandatory_field_check.line(format!("None => return Err(DeserializeFailure::MandatoryFieldMissing({key}).into()),")); + mandatory_field_check.after(";"); + deser_code.content.push_block(mandatory_field_check); } - FixedValue::Bool(_) - | FixedValue::Nint(_) - | FixedValue::Null - | FixedValue::Float(_) - | FixedValue::Uint(_) => { - deser_code.content.line(&format!( - "let {} = {}.unwrap_or({});", - field.name, + } else if let Some(default_value) = &field.rust_type.config.default { + if cli.preserve_encodings { + let mut default_present_check = Block::new(format!( + "if {} == Some({})", field.name, default_value.to_primitive_str_assign() )); + default_present_check + .line(format!("{}_default_present = true;", field.name)); + deser_code.content.push_block(default_present_check); + } + match default_value { + FixedValue::Text(_) => { + // to avoid clippy::or_fun_call + deser_code.content.line(&format!( + "let {} = {}.unwrap_or_else(|| {});", + field.name, + field.name, + default_value.to_primitive_str_assign() + )); + } + FixedValue::Bool(_) + | FixedValue::Nint(_) + | FixedValue::Null + | FixedValue::Float(_) + | FixedValue::Uint(_) => { + deser_code.content.line(&format!( + "let {} = {}.unwrap_or({});", + field.name, + field.name, + default_value.to_primitive_str_assign() + )); + } } } + if !field.rust_type.is_fixed_value() { + ctor_block.line(format!("{},", field.name)); + } } - if !field.rust_type.is_fixed_value() { - ctor_block.line(format!("{},", 
field.name)); - } - } - if cli.preserve_encodings { - let mut encoding_ctor = Block::new(format!("encodings: Some({name}Encoding")); - if tag.is_some() { - encoding_ctor.line("tag_encoding: Some(tag_encoding),"); - } - encoding_ctor - .line("len_encoding,") - .line("orig_deser_order,"); - for field in record.fields.iter() { - let key_enc = key_encoding_field(&field.name, field.key.as_ref().unwrap()); - encoding_ctor.line(format!("{},", key_enc.field_name)); - for field_enc in encoding_fields( - types, - &field.name, - &field.rust_type.clone().resolve_aliases(), - true, - cli, - ) { - encoding_ctor.line(format!("{},", field_enc.field_name)); + if cli.preserve_encodings { + let mut encoding_ctor = Block::new(format!("encodings: Some({name}Encoding")); + if tag.is_some() { + encoding_ctor.line("tag_encoding: Some(tag_encoding),"); + } + encoding_ctor + .line("len_encoding,") + .line("orig_deser_order,"); + for field in record.fields.iter() { + let key_enc = key_encoding_field(&field.name, field.key.as_ref().unwrap()); + encoding_ctor.line(format!("{},", key_enc.field_name)); + for field_enc in encoding_fields( + types, + &field.name, + &field.rust_type.clone().resolve_aliases(), + true, + cli, + ) { + encoding_ctor.line(format!("{},", field_enc.field_name)); + } } + encoding_ctor.after("),"); + ctor_block.push_block(encoding_ctor); } - encoding_ctor.after("),"); - ctor_block.push_block(encoding_ctor); + ctor_block.after(")"); + ctor_block } - ctor_block.after(")"); - ctor_block - } - }; - let len_enc_var = len_encoding_var - .map(|var| format!("self.encodings.as_ref().map(|encs| encs.{var}).unwrap_or_default()")) - .unwrap_or_default(); - end_len(&mut ser_func, "serializer", &len_enc_var, true, cli); - match &mut ser_embedded_impl { - Some(ser_embedded_impl) => ser_embedded_impl.push_fn(ser_func), - None => ser_impl.push_fn(ser_func), - }; - let mut deser_scaffolding = BlocksOrLines::default(); - let (mut deser_impl, mut deser_embedded_impl) = 
create_deserialize_impls( - name, - Some(record.rep), - tag, - Some(record.cbor_len_info(types)), - types.is_plain_group(name), - len_encoding_var, - &mut deser_scaffolding, - cli, - ); - if deser_embedded_impl.is_none() { - // ending checks are included with embedded serialization setup - // since we are populating deserialize_as_embedded_group() and deserialize() - // is already complete - // but these checks must be done manually here *after* we populate deserialize() - add_deserialize_final_len_check( - &mut deser_code.content, + }; + let len_enc_var = len_encoding_var + .map(|var| { + format!("self.encodings.as_ref().map(|encs| encs.{var}).unwrap_or_default()") + }) + .unwrap_or_default(); + end_len(&mut ser_func, "serializer", &len_enc_var, true, cli); + match &mut ser_embedded_impl { + Some(ser_embedded_impl) => ser_embedded_impl.push_fn(ser_func), + None => ser_impl.push_fn(ser_func), + }; + let mut deser_scaffolding = BlocksOrLines::default(); + let (mut deser_impl, mut deser_embedded_impl) = create_deserialize_impls( + name, Some(record.rep), - record.cbor_len_info(types), + tag, + Some(record.cbor_len_info(types)), + types.is_plain_group(name), + len_encoding_var, + &mut deser_scaffolding, cli, ); - } - deser_code.content.push_block(ctor_block); + if deser_embedded_impl.is_none() { + // ending checks are included with embedded serialization setup + // since we are populating deserialize_as_embedded_group() and deserialize() + // is already complete + // but these checks must be done manually here *after* we populate deserialize() + add_deserialize_final_len_check( + &mut deser_code.content, + Some(record.rep), + record.cbor_len_info(types), + cli, + ); + } + deser_code.content.push_block(ctor_block); - if cli.annotate_fields { - deser_code = deser_code.annotate(name.as_ref(), "", ""); - } + if cli.annotate_fields { + deser_code = deser_code.annotate(name.as_ref(), "", ""); + } - if let Some(deser_embedded_impl) = &mut deser_embedded_impl { - let mut 
deser_f = make_deserialization_function("deserialize"); - deser_f.push_all(deser_scaffolding); - deser_impl.push_fn(deser_f); - let mut deser_embed_f = make_deserialization_function("deserialize_as_embedded_group"); - let read_len_arg = if deser_code.read_len_used { - "read_len" - } else { - "_read_len" - }; - deser_embed_f.arg(read_len_arg, "&mut CBORReadLen"); - if cli.preserve_encodings { - deser_embed_f.arg("len", "cbor_event::LenSz"); + if let Some(deser_embedded_impl) = &mut deser_embedded_impl { + let mut deser_f = make_deserialization_function("deserialize"); + deser_f.push_all(deser_scaffolding); + deser_impl.push_fn(deser_f); + let mut deser_embed_f = make_deserialization_function("deserialize_as_embedded_group"); + let read_len_arg = if deser_code.read_len_used { + "read_len" + } else { + "_read_len" + }; + deser_embed_f.arg(read_len_arg, "&mut CBORReadLen"); + if cli.preserve_encodings { + deser_embed_f.arg("len", "cbor_event::LenSz"); + } else { + deser_embed_f.arg("len", "cbor_event::Len"); + } + // this is expected when creating the final struct but wouldn't have been available + // otherwise as it is in the non-embedded deserialiation function + if cli.preserve_encodings { + deser_embed_f.line("let len_encoding = len.into();"); + } + deser_embed_f.push_all(deser_code.content); + deser_embedded_impl.push_fn(deser_embed_f); } else { - deser_embed_f.arg("len", "cbor_event::Len"); + let mut deser_f = make_deserialization_function("deserialize"); + deser_f.push_all(deser_scaffolding); + deser_f.push_all(deser_code.content); + deser_impl.push_fn(deser_f); } - // this is expected when creating the final struct but wouldn't have been available - // otherwise as it is in the non-embedded deserialiation function - if cli.preserve_encodings { - deser_embed_f.line("let len_encoding = len.into();"); + + if config.custom_serialize.is_none() { + gen_scope.rust_serialize(types, name).push_impl(ser_impl); + if let Some(s) = ser_embedded_impl { + 
gen_scope.rust_serialize(types, name).push_impl(s); + } + } + + // TODO: generic deserialize (might need backtracking) + if gen_scope.deserialize_generated(name) { + gen_scope.rust_serialize(types, name).push_impl(deser_impl); + if let Some(deser_embedded_impl) = deser_embedded_impl { + gen_scope + .rust_serialize(types, name) + .push_impl(deser_embedded_impl); + } } - deser_embed_f.push_all(deser_code.content); - deser_embedded_impl.push_fn(deser_embed_f); - } else { - let mut deser_f = make_deserialization_function("deserialize"); - deser_f.push_all(deser_scaffolding); - deser_f.push_all(deser_code.content); - deser_impl.push_fn(deser_f); } - push_rust_struct( - gen_scope, - types, - name, - native_struct, - native_impl, - ser_impl, - ser_embedded_impl, - ); + + gen_scope + .rust(types, name) + .push_struct(native_struct) + .push_impl(native_impl); + // for clippy we generate a Default when new takes no args. // We keep new() for consistency with other types. if new_arg_count == 0 { @@ -5844,17 +6150,9 @@ fn codegen_struct( .line("Self::new()"); gen_scope.rust(types, name).push_impl(default_impl); } - // TODO: generic deserialize (might need backtracking) - if gen_scope.deserialize_generated(name) { - gen_scope.rust_serialize(types, name).push_impl(deser_impl); - if let Some(deser_embedded_impl) = deser_embedded_impl { - gen_scope - .rust_serialize(types, name) - .push_impl(deser_embedded_impl); - } - } } +#[allow(clippy::too_many_arguments)] fn codegen_group_choices( gen_scope: &mut GenerationScope, types: &IntermediateTypes, @@ -5862,10 +6160,21 @@ fn codegen_group_choices( variants: &[EnumVariant], rep: Representation, tag: Option, + config: &RustStructConfig, cli: &Cli, ) { // rust inner enum - generate_enum(gen_scope, types, name, variants, Some(rep), false, tag, cli); + generate_enum( + gen_scope, + types, + name, + variants, + Some(rep), + false, + tag, + config, + cli, + ); // wasm wrapper if cli.wasm { @@ -6156,6 +6465,9 @@ impl EnumVariantInRust { 
field_name: "len_encoding".to_owned(), type_name: "LenEncoding".to_owned(), default_expr: "LenEncoding::default()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: true, inner: Vec::new(), }); outer_vars += 1; @@ -6182,6 +6494,9 @@ impl EnumVariantInRust { field_name: "len_encoding".to_owned(), type_name: "LenEncoding".to_owned(), default_expr: "LenEncoding::default()", + enc_conversion_before: "", + enc_conversion_after: "", + is_copy: true, inner: Vec::new(), }); for field in record.fields.iter() { @@ -6319,6 +6634,7 @@ fn generate_c_style_enum( name: &RustIdent, variants: &[EnumVariant], tag: Option, + config: &RustStructConfig, cli: &Cli, ) -> bool { if tag.is_some() && cli.preserve_encodings { @@ -6354,7 +6670,13 @@ fn generate_c_style_enum( ) .vis("pub"); } - add_struct_derives(&mut e, types.used_as_key(name), true, false, cli); + add_struct_derives( + &mut e, + types.used_as_key(name), + true, + config.custom_json, + cli, + ); for variant in variants.iter() { e.new_variant(variant.name.to_string()); } @@ -6494,6 +6816,7 @@ fn generate_enum( rep: Option, generate_deserialize_directly: bool, tag: Option, + config: &RustStructConfig, cli: &Cli, ) { if cli.wasm { @@ -6514,7 +6837,13 @@ fn generate_enum( // instead of using create_serialize_impl() and having the length encoded there, we want to make it easier // to offer definite length encoding even if we're mixing plain group members and non-plain group members (or mixed length plain ones) // by potentially wrapping the choices with the array/map tag in the variant branch when applicable - add_struct_derives(&mut e, types.used_as_key(name), true, false, cli); + add_struct_derives( + &mut e, + types.used_as_key(name), + true, + config.custom_json, + cli, + ); let mut ser_impl = make_serialization_impl(name.as_ref(), cli); let mut ser_func = make_serialization_function("serialize", cli); if let Some(tag) = tag { @@ -7113,7 +7442,7 @@ fn generate_wrapper_struct( type_name: &RustIdent, 
field_type: &RustType, min_max: Option<(Option, Option)>, - custom_json: bool, + struct_config: &RustStructConfig, cli: &Cli, ) { if min_max.is_some() { @@ -7172,7 +7501,7 @@ fn generate_wrapper_struct( Cow::Owned(field_type.for_rust_member(types, false, cli)) }; - if !custom_json { + if !struct_config.custom_json { // serde Serialize / Deserialize if cli.json_serde_derives { let mut serde_ser_fn = codegen::Function::new("serialize"); @@ -7591,7 +7920,7 @@ fn generate_wrapper_struct( .push_impl(s_impl) .push_impl(from_impl) .push_impl(from_inner_impl); - if !custom_json { + if !struct_config.custom_json { if cli.json_serde_derives { gen_scope .rust(types, type_name) diff --git a/src/intermediate.rs b/src/intermediate.rs index 6233416..4229af9 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -4,6 +4,7 @@ use cddl::ast::parent::ParentVisitor; use std::borrow::Cow; use std::collections::{BTreeMap, BTreeSet}; +use crate::comment_ast::RuleMetadata; // TODO: move all of these generation specifics into generation.rs use crate::generation::table_type; use crate::utils::{ @@ -44,14 +45,27 @@ pub struct AliasInfo { pub base_type: RustType, pub gen_rust_alias: bool, pub gen_wasm_alias: bool, + pub rule_metadata: Option, } impl AliasInfo { - pub fn new(base_type: RustType, gen_rust_alias: bool, gen_wasm_alias: bool) -> Self { + pub fn new_manual(base_type: RustType, gen_rust_alias: bool, gen_wasm_alias: bool) -> Self { Self { base_type, gen_rust_alias, gen_wasm_alias, + rule_metadata: None, + } + } + + pub fn new_from_metadata(base_type: RustType, rule_metadata: RuleMetadata) -> Self { + let gen_rust_alias = !rule_metadata.no_alias; + let gen_wasm_alias = !rule_metadata.no_alias; + Self { + base_type, + gen_rust_alias, + gen_wasm_alias, + rule_metadata: Some(rule_metadata), } } } @@ -269,7 +283,7 @@ impl<'a> IntermediateTypes<'a> { let mut aliases = BTreeMap::::new(); let mut insert_alias = |name: &str, rust_type: RustType| { let ident = 
AliasIdent::new(CDDLIdent::new(name)); - aliases.insert(ident, AliasInfo::new(rust_type, false, false)); + aliases.insert(ident, AliasInfo::new_manual(rust_type, false, false)); }; insert_alias("uint", ConceptualRustType::Primitive(Primitive::U64).into()); insert_alias("nint", ConceptualRustType::Primitive(Primitive::N64).into()); @@ -441,7 +455,7 @@ impl<'a> IntermediateTypes<'a> { } self.register_type_alias( rust_struct.ident.clone(), - AliasInfo::new(map_type, true, false), + AliasInfo::new_manual(map_type, true, false), ) } RustStructType::Array { element_type } => { @@ -452,7 +466,7 @@ impl<'a> IntermediateTypes<'a> { } self.register_type_alias( rust_struct.ident.clone(), - AliasInfo::new(array_type, true, false), + AliasInfo::new_manual(array_type, true, false), ) } RustStructType::Wrapper { @@ -497,7 +511,7 @@ impl<'a> IntermediateTypes<'a> { // 2 separate types (array wrapper -> tag wrapper struct) self.register_rust_struct( parent_visitor, - RustStruct::new_array(array_type_ident, None, element_type.clone()), + RustStruct::new_array(array_type_ident, None, None, element_type.clone()), cli, ); } @@ -540,7 +554,11 @@ impl<'a> IntermediateTypes<'a> { // but wasm_bindgen can't work with it directly we assume the user will supply the correct mappings self.register_type_alias( instance_ident, - AliasInfo::new(ConceptualRustType::Rust(real_ident).into(), true, false), + AliasInfo::new_manual( + ConceptualRustType::Rust(real_ident).into(), + true, + false, + ), ); } } @@ -2147,15 +2165,24 @@ pub struct RustField { pub optional: bool, // None for array fields, Some for map fields. 
FixedValue for (de)serialization for map keys pub key: Option, + // comment DSL metadata applied to this field + pub rule_metadata: RuleMetadata, } impl RustField { - pub fn new(name: String, rust_type: RustType, optional: bool, key: Option) -> Self { + pub fn new( + name: String, + rust_type: RustType, + optional: bool, + key: Option, + rule_metadata: RuleMetadata, + ) -> Self { Self { name, rust_type, optional, key, + rule_metadata, } } @@ -2180,6 +2207,26 @@ pub enum RustStructCBORLen { OptionalFields(usize), } +#[derive(Clone, Debug, Default)] +pub struct RustStructConfig { + pub custom_json: bool, + pub custom_serialize: Option, + pub custom_deserialize: Option, +} + +impl From> for RustStructConfig { + fn from(rule_metadata: Option<&RuleMetadata>) -> Self { + match rule_metadata { + Some(rule_metadata) => Self { + custom_json: rule_metadata.custom_json, + custom_serialize: rule_metadata.custom_serialize.clone(), + custom_deserialize: rule_metadata.custom_deserialize.clone(), + }, + None => Self::default(), + } + } +} + // TODO: It would be nice to separate parsing the CDDL lib structs and code generation entirely. // We would just need to construct these structs (+ maybe the array/table wrapper types) separately and pass these into codegen. // This would also give us more access to this info without reparsing which could simplify code in some places. @@ -2189,6 +2236,7 @@ pub enum RustStructCBORLen { pub struct RustStruct { ident: RustIdent, tag: Option, + config: RustStructConfig, pub(crate) variant: RustStructType, } @@ -2212,7 +2260,6 @@ pub enum RustStructType { Wrapper { wrapped: RustType, min_max: Option<(Option, Option)>, - custom_json: bool, }, /// This is a no-op in generation but to prevent lookups of things in the prelude /// e.g. 
`int` from not being resolved while still being able to detect it when @@ -2225,10 +2272,16 @@ pub enum RustStructType { } impl RustStruct { - pub fn new_record(ident: RustIdent, tag: Option, record: RustRecord) -> Self { + pub fn new_record( + ident: RustIdent, + tag: Option, + rule_metadata: Option<&RuleMetadata>, + record: RustRecord, + ) -> Self { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::Record(record), } } @@ -2236,20 +2289,28 @@ impl RustStruct { pub fn new_table( ident: RustIdent, tag: Option, + rule_metadata: Option<&RuleMetadata>, domain: RustType, range: RustType, ) -> Self { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::Table { domain, range }, } } - pub fn new_array(ident: RustIdent, tag: Option, element_type: RustType) -> Self { + pub fn new_array( + ident: RustIdent, + tag: Option, + rule_metadata: Option<&RuleMetadata>, + element_type: RustType, + ) -> Self { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::Array { element_type }, } } @@ -2258,6 +2319,7 @@ impl RustStruct { pub fn new_type_choice( ident: RustIdent, tag: Option, + rule_metadata: Option<&RuleMetadata>, variants: Vec, cli: &Cli, ) -> Self { @@ -2280,12 +2342,14 @@ impl RustStruct { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::TypeChoice { variants }, } } else { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::CStyleEnum { variants }, } } @@ -2294,12 +2358,14 @@ impl RustStruct { pub fn new_group_choice( ident: RustIdent, tag: Option, + rule_metadata: Option<&RuleMetadata>, variants: Vec, rep: Representation, ) -> Self { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::GroupChoice { variants, rep }, } } @@ -2307,17 +2373,17 @@ impl RustStruct { pub fn new_wrapper( ident: RustIdent, tag: Option, + rule_metadata: 
Option<&RuleMetadata>, wrapped_type: RustType, min_max: Option<(Option, Option)>, - custom_json: bool, ) -> Self { Self { ident, tag, + config: RustStructConfig::from(rule_metadata), variant: RustStructType::Wrapper { wrapped: wrapped_type, min_max, - custom_json, }, } } @@ -2326,6 +2392,7 @@ impl RustStruct { Self { ident, tag: None, + config: RustStructConfig::default(), variant: RustStructType::Extern, } } @@ -2334,6 +2401,7 @@ impl RustStruct { Self { ident, tag: None, + config: RustStructConfig::default(), variant: RustStructType::RawBytesType, } } @@ -2346,6 +2414,10 @@ impl RustStruct { self.tag } + pub fn config(&self) -> &RustStructConfig { + &self.config + } + pub fn variant(&self) -> &RustStructType { &self.variant } diff --git a/src/parsing.rs b/src/parsing.rs index 1f8b00b..8751737 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -170,7 +170,7 @@ fn parse_type_choices( let rule_metadata = RuleMetadata::from(inner_type2.comments_after_type.as_ref()); types.register_type_alias( name.clone(), - AliasInfo::new(final_type, !rule_metadata.no_alias, !rule_metadata.no_alias), + AliasInfo::new_from_metadata(final_type, rule_metadata), ); } else { let rule_metadata = merge_metadata( @@ -189,7 +189,8 @@ fn parse_type_choices( types.mark_used_as_key(name.clone()); } let variants = create_variants_from_type_choices(types, parent_visitor, type_choices, cli); - let rust_struct = RustStruct::new_type_choice(name.clone(), tag, variants, cli); + let rust_struct = + RustStruct::new_type_choice(name.clone(), tag, Some(&rule_metadata), variants, cli); match generic_params { Some(params) => types.register_generic_def(GenericDef::new(params, rust_struct)), None => types.register_rust_struct(parent_visitor, rust_struct, cli), @@ -523,9 +524,9 @@ fn parse_type( RustStruct::new_wrapper( type_name.clone(), outer_tag, + Some(&rule_metadata), ranged_type, Some(min_max), - rule_metadata.custom_json, ), cli, ); @@ -533,10 +534,9 @@ fn parse_type( // matches to known rust type 
e.g. u32, i16, etc so just make an alias types.register_type_alias( type_name.clone(), - AliasInfo::new( + AliasInfo::new_from_metadata( ranged_type.tag_if(outer_tag), - !rule_metadata.no_alias, - !rule_metadata.no_alias, + rule_metadata, ), ); } @@ -545,10 +545,9 @@ fn parse_type( Some(Primitive::Bytes) => { types.register_type_alias( type_name.clone(), - AliasInfo::new( + AliasInfo::new_from_metadata( ty.as_bytes().tag_if(outer_tag), - !rule_metadata.no_alias, - !rule_metadata.no_alias, + rule_metadata, ), ); } @@ -561,11 +560,7 @@ fn parse_type( .tag_if(outer_tag); types.register_type_alias( type_name.clone(), - AliasInfo::new( - default_type, - !rule_metadata.no_alias, - !rule_metadata.no_alias, - ), + AliasInfo::new_from_metadata(default_type, rule_metadata), ); } } @@ -612,19 +607,18 @@ fn parse_type( RustStruct::new_wrapper( type_name.clone(), None, + Some(&rule_metadata), concrete_type, None, - rule_metadata.custom_json, ), cli, ); } else { types.register_type_alias( type_name.clone(), - AliasInfo::new( + AliasInfo::new_from_metadata( concrete_type, - !rule_metadata.no_alias, - !rule_metadata.no_alias, + rule_metadata, ), ); } @@ -710,11 +704,7 @@ fn parse_type( }; types.register_type_alias( type_name.clone(), - AliasInfo::new( - base_type.tag_if(outer_tag), - !rule_metadata.no_alias, - !rule_metadata.no_alias, - ), + AliasInfo::new_from_metadata(base_type.tag_if(outer_tag), rule_metadata), ); } Type2::UintValue { value, .. } => { @@ -733,23 +723,18 @@ fn parse_type( }; types.register_type_alias( type_name.clone(), - AliasInfo::new( - base_type.tag_if(outer_tag), - !rule_metadata.no_alias, - !rule_metadata.no_alias, - ), + AliasInfo::new_from_metadata(base_type.tag_if(outer_tag), rule_metadata), ); } Type2::TextValue { value, .. 
} => { types.register_type_alias( type_name.clone(), - AliasInfo::new( + AliasInfo::new_from_metadata( RustType::new(ConceptualRustType::Fixed(FixedValue::Text( value.to_string(), ))) .tag_if(outer_tag), - !rule_metadata.no_alias, - !rule_metadata.no_alias, + rule_metadata, ), ); } @@ -769,7 +754,7 @@ fn parse_type( }; types.register_type_alias( type_name.clone(), - AliasInfo::new(base_type.tag_if(outer_tag), true, true), + AliasInfo::new_from_metadata(base_type.tag_if(outer_tag), rule_metadata), ); } x => { @@ -1082,6 +1067,24 @@ fn group_entry_to_raw_field_name(entry: &GroupEntry) -> Option { } } +fn group_entry_rule_metadata(entry: &GroupEntry, optional_comma: &OptionalComma) -> RuleMetadata { + let entry_trailing_comments = match entry { + GroupEntry::ValueMemberKey { + trailing_comments, .. + } => trailing_comments, + GroupEntry::TypeGroupname { + trailing_comments, .. + } => trailing_comments, + GroupEntry::InlineGroup { group, .. } => panic!( + "not implemented (define a new struct for this!) 
= {}\n\n {:?}", + group, group + ), + }; + let combined_comments = + combine_comments(entry_trailing_comments, &optional_comma.trailing_comments); + metadata_from_comments(&combined_comments.unwrap_or_default()) +} + fn rust_type_from_type1( types: &mut IntermediateTypes, parent_visitor: &ParentVisitor, @@ -1286,6 +1289,9 @@ fn rust_type( cli, ) } else { + let rule_metadata = RuleMetadata::from( + get_comment_after(parent_visitor, &CDDLType::from(t), None).as_ref(), + ); if t.type_choices.len() == 2 { // T / null or null / T should map to Option let a = &t.type_choices[0].type1; @@ -1324,7 +1330,7 @@ fn rust_type( let combined_ident = RustIdent::new(CDDLIdent::new(&combined_name)); types.register_rust_struct( parent_visitor, - RustStruct::new_type_choice(combined_ident, None, variants, cli), + RustStruct::new_type_choice(combined_ident, None, Some(&rule_metadata), variants, cli), cli, ); types.new_type(&CDDLIdent::new(combined_name), cli) @@ -1411,6 +1417,7 @@ fn parse_record_from_group_choice( &mut generated_fields, optional_comma, ); + let rule_metadata = group_entry_rule_metadata(group_entry, optional_comma); // does not exist for fixed values importantly let field_type = group_entry_to_type(types, parent_visitor, group_entry, cli); if let ConceptualRustType::Rust(ident) = &field_type.conceptual_type { @@ -1423,7 +1430,7 @@ fn parse_record_from_group_choice( } Representation::Array => None, }; - RustField::new(field_name, field_type, optional_field, key) + RustField::new(field_name, field_type, optional_field, key, rule_metadata) }) .collect(); RustRecord { rep, fields } @@ -1440,21 +1447,30 @@ fn parse_group_choice( generic_params: Option>, cli: &Cli, ) { + let rule_metadata = RuleMetadata::from( + get_comment_after(parent_visitor, &CDDLType::from(group_choice), None).as_ref(), + ); let rust_struct = match parse_group_type(types, parent_visitor, group_choice, rep, cli) { GroupParsingType::HomogenousArray(element_type) => { // Array - homogeneous element type 
with proper occurence operator - RustStruct::new_array(name.clone(), tag, element_type) + RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type) } GroupParsingType::HomogenousMap(key_type, value_type) => { // Table map - homogeneous key/value types - RustStruct::new_table(name.clone(), tag, key_type, value_type) + RustStruct::new_table( + name.clone(), + tag, + Some(&rule_metadata), + key_type, + value_type, + ) } GroupParsingType::Heterogenous | GroupParsingType::WrappedBasicGroup(_) => { // Heterogenous map or array with defined key/value pairs in the cddl like a struct let record = parse_record_from_group_choice(types, rep, parent_visitor, group_choice, cli); // We need to store this in IntermediateTypes so we can refer from one struct to another. - RustStruct::new_record(name.clone(), tag, record) + RustStruct::new_record(name.clone(), tag, Some(&rule_metadata), record) } }; match generic_params { @@ -1574,9 +1590,12 @@ pub fn parse_group( } }) .collect(); + let rule_metadata = RuleMetadata::from( + get_comment_after(parent_visitor, &CDDLType::from(group), None).as_ref(), + ); types.register_rust_struct( parent_visitor, - RustStruct::new_group_choice(name.clone(), tag, variants, rep), + RustStruct::new_group_choice(name.clone(), tag, Some(&rule_metadata), variants, rep), cli, ); } diff --git a/src/test.rs b/src/test.rs index a94d2dc..375ce2c 100644 --- a/src/test.rs +++ b/src/test.rs @@ -6,8 +6,8 @@ fn run_test( dir: &str, options: &[&str], export_suffix: Option<&str>, - external_rust_file_path: Option, - external_wasm_file_path: Option, + external_rust_file_paths: &[std::path::PathBuf], + external_wasm_file_paths: &[std::path::PathBuf], input_is_dir: bool, test_deps: &[&str], ) { @@ -53,8 +53,8 @@ fn run_test( lib_rs .write_all("\nuse serialization::*;\n".as_bytes()) .unwrap(); - // copy external file in too (if needed) too - if let Some(external_rust_file_path) = external_rust_file_path { + // copy external files in too (if needed) too + 
for external_rust_file_path in external_rust_file_paths { let extern_rs = std::fs::read_to_string(external_rust_file_path).unwrap(); lib_rs.write_all("\n\n".as_bytes()).unwrap(); lib_rs.write_all(extern_rs.as_bytes()).unwrap(); @@ -103,7 +103,7 @@ fn run_test( let wasm_export_dir = test_path.join(format!("{export_path}/wasm")); let wasm_test_dir = test_path.join("tests_wasm.rs"); // copy external wasm defs if they exist - if let Some(external_wasm_file_path) = external_wasm_file_path { + for external_wasm_file_path in external_wasm_file_paths { println!("trying to open: {external_wasm_file_path:?}"); let mut wasm_lib_rs = std::fs::OpenOptions::new() .append(true) @@ -183,12 +183,15 @@ fn core_with_wasm() { let extern_wasm_path = std::path::PathBuf::from_str("tests") .unwrap() .join("external_wasm_defs"); + let custom_ser_path = std::path::PathBuf::from_str("tests") + .unwrap() + .join("custom_serialization"); run_test( "core", &[], Some("wasm"), - Some(extern_rust_path), - Some(extern_wasm_path), + &[extern_rust_path, custom_ser_path], + &[extern_wasm_path], false, &[], ); @@ -200,12 +203,15 @@ fn core_no_wasm() { let extern_rust_path = std::path::PathBuf::from_str("tests") .unwrap() .join("external_rust_defs"); + let custom_ser_path = std::path::PathBuf::from_str("tests") + .unwrap() + .join("custom_serialization"); run_test( "core", &["--wasm=false"], None, - Some(extern_rust_path), - None, + &[extern_rust_path, custom_ser_path], + &[], false, &[], ); @@ -217,8 +223,8 @@ fn comment_dsl() { "comment-dsl", &["--preserve-encodings=true"], None, - None, - None, + &[], + &[], false, &[], ); @@ -226,12 +232,16 @@ fn comment_dsl() { #[test] fn preserve_encodings() { + use std::str::FromStr; + let custom_ser_path = std::path::PathBuf::from_str("tests") + .unwrap() + .join("custom_serialization_preserve"); run_test( "preserve-encodings", &["--preserve-encodings=true"], None, - None, - None, + &[custom_ser_path], + &[], false, &[], ); @@ -243,8 +253,8 @@ fn canonical() { 
"canonical", &["--preserve-encodings=true", "--canonical-form=true"], None, - None, - None, + &[], + &[], false, &[], ); @@ -252,7 +262,7 @@ fn canonical() { #[test] fn rust_wasm_split() { - run_test("rust-wasm-split", &[], None, None, None, false, &[]); + run_test("rust-wasm-split", &[], None, &[], &[], false, &[]); } #[test] @@ -269,10 +279,10 @@ fn multifile() { "multifile", &[], None, - Some(extern_rust_path), - Some(extern_wasm_path), + &[extern_rust_path], + &[extern_wasm_path], true, - &[], + &["hex = \"0.4.3\""], ); } @@ -297,8 +307,8 @@ fn multifile_json_preserve() { "--json-schema-export=true", ], Some("json_preserve"), - Some(extern_rust_path), - Some(extern_wasm_path), + &[extern_rust_path], + &[extern_wasm_path], true, &[], ); @@ -317,8 +327,8 @@ fn raw_bytes() { "raw-bytes", &[], None, - Some(extern_rust_path), - Some(extern_wasm_path), + &[extern_rust_path], + &[extern_wasm_path], false, &[], ); @@ -337,8 +347,8 @@ fn raw_bytes_preserve() { "raw-bytes-preserve", &["--preserve-encodings=true"], None, - Some(extern_rust_path), - Some(extern_wasm_path), + &[extern_rust_path], + &[extern_wasm_path], false, &[], ); @@ -354,8 +364,8 @@ fn json() { "json", &["--json-serde-derives=true", "--json-schema-export=true"], None, - Some(extern_rust_path), - None, + &[extern_rust_path], + &[], false, &[], ); @@ -375,8 +385,8 @@ fn json_preserve() { "--json-schema-export=true", ], Some("preserve"), - Some(extern_rust_path), - None, + &[extern_rust_path], + &[], false, &[], ); diff --git a/tests/core/input.cddl b/tests/core/input.cddl index 25f00d5..f53521f 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -197,3 +197,13 @@ enum_opt_embed_fields = [ ; @name eg 1, ? 
overlapping_inlined, #6.13(13) ] + +custom_bytes = bytes ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + +struct_with_custom_serialization = [ + custom_bytes, + field: bytes, ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + overridden: custom_bytes, ; @custom_serialize write_hex_string @custom_deserialize read_hex_string + tagged1: #6.9(custom_bytes), + tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str +] diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 10937a3..740dd85 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -467,4 +467,29 @@ mod tests { let g2 = EnumOptEmbedFields::new_eg(None); deser_test(&g2); } + + #[test] + fn custom_serialization() { + let struct_with_custom_bytes = StructWithCustomSerialization::new( + vec![0xCA, 0xFE, 0xF0, 0x0D], + vec![0x03, 0x01, 0x04, 0x01], + vec![0xBA, 0xAD, 0xD0, 0x0D], + vec![0xDE, 0xAD, 0xBE, 0xEF], + 1024, + ); + use cbor_event::{Sz, StringLenSz}; + let bytes_special_enc = StringLenSz::Indefinite(vec![(1, Sz::Inline), (1, Sz::Inline), (1, Sz::Inline), (1, Sz::Inline)]); + deser_test(&struct_with_custom_bytes); + let expected_bytes = vec![ + arr_def(5), + cbor_bytes_sz(vec![0xCA, 0xFE, 0xF0, 0x0D], bytes_special_enc.clone()), + cbor_bytes_sz(vec![0x03, 0x01, 0x04, 0x01], bytes_special_enc.clone()), + cbor_string("baadd00d"), + cbor_tag(9), + cbor_bytes_sz(vec![0xDE, 0xAD, 0xBE, 0xEF], bytes_special_enc.clone()), + cbor_tag(9), + cbor_string("1024") + ].into_iter().flatten().clone().collect::>(); + assert_eq!(expected_bytes, struct_with_custom_bytes.to_cbor_bytes()); + } } diff --git a/tests/custom_serialization b/tests/custom_serialization new file mode 100644 index 0000000..8e5bb1e --- /dev/null +++ b/tests/custom_serialization @@ -0,0 +1,64 @@ +// writes bytes using indefinite encoding chunked into 1-byte parts +pub fn custom_serialize_bytes<'se, W: 
std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + bytes: &[u8], +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + serializer.write_raw_bytes(&[0x5f])?; + for byte in bytes { + serializer.write_bytes(&[*byte])?; + } + serializer.write_special(cbor_event::Special::Break) +} + +// read bytes and verify the 1-byte chunking of custom_serialize_bytes() +pub fn custom_deserialize_bytes( + raw: &mut cbor_event::de::Deserializer, +) -> Result, DeserializeError> { + let (bytes, bytes_enc) = raw.bytes_sz()?; + match bytes_enc { + cbor_event::StringLenSz::Len(_sz) => Err(DeserializeFailure::CBOR(cbor_event::Error::CustomError("custom_deserialize_bytes(): needs indefinite chunking".to_owned())).into()), + cbor_event::StringLenSz::Indefinite(chunks) => { + for (chunk_len, _chunk_sz) in chunks.iter() { + if *chunk_len != 1 { + return Err(DeserializeFailure::CBOR(cbor_event::Error::CustomError(format!("custom_deserialize_bytes(): chunks need to be 1-len, found: {:?}", chunks))).into()); + } + } + Ok(bytes) + } + } +} + +// writes as hex text +pub fn write_hex_string<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + bytes: &[u8], +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + serializer.write_text(hex::encode(bytes)) +} + +// reads hex text to bytes +pub fn read_hex_string( + raw: &mut cbor_event::de::Deserializer, +) -> Result, DeserializeError> { + let text = raw.text()?; + hex::decode(text).map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into()) +} + +// must include the tag since @custom_serialize at field-level overrides everything +pub fn write_tagged_uint_str<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + uint: &u64, +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + serializer + .write_tag(9)? 
+ .write_text(uint.to_string()) +} + +pub fn read_tagged_uint_str( + raw: &mut cbor_event::de::Deserializer, +) -> Result { + use std::str::FromStr; + let tag = raw.tag()?; + let text = raw.text()?; + u64::from_str(&text).map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into()) +} diff --git a/tests/custom_serialization_preserve b/tests/custom_serialization_preserve new file mode 100644 index 0000000..f03e937 --- /dev/null +++ b/tests/custom_serialization_preserve @@ -0,0 +1,86 @@ +// writes bytes using indefinite encoding chunked into 1-byte parts +pub fn custom_serialize_bytes<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + bytes: &[u8], + enc: &StringEncoding, +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + let szs = match enc { + StringEncoding::Indefinite(encs) => { + encs.iter().map(|(_l, e)| *e).chain(std::iter::repeat(cbor_event::Sz::Inline)).take(bytes.len()).collect::>() + } + _ => std::iter::repeat(cbor_event::Sz::Inline).take(bytes.len()).collect::>() + }; + serializer.write_raw_bytes(&[0x5f])?; + for (sz, byte) in szs.iter().zip(bytes.iter()) { + serializer.write_bytes_sz(&[*byte], cbor_event::StringLenSz::Len(*sz))?; + } + serializer.write_special(cbor_event::Special::Break) +} + +// read bytes and verify the 1-byte chunking of custom_serialize_bytes() +pub fn custom_deserialize_bytes( + raw: &mut cbor_event::de::Deserializer, +) -> Result<(Vec, StringEncoding), DeserializeError> { + let (bytes, bytes_enc) = raw.bytes_sz()?; + match &bytes_enc { + cbor_event::StringLenSz::Len(_sz) => Err(DeserializeFailure::CBOR(cbor_event::Error::CustomError("custom_deserialize_bytes(): needs indefinite chunking".to_owned())).into()), + cbor_event::StringLenSz::Indefinite(chunks) => { + for (chunk_len, _chunk_sz) in chunks.iter() { + if *chunk_len != 1 { + return Err(DeserializeFailure::CBOR(cbor_event::Error::CustomError(format!("custom_deserialize_bytes(): chunks need to be 1-len, found: {:?}", 
chunks))).into()); + } + } + Ok((bytes, bytes_enc.into())) + } + } +} + +// writes as hex text +pub fn write_hex_string<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + bytes: &[u8], + enc: &StringEncoding, +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + serializer.write_text_sz(hex::encode(bytes), enc.to_str_len_sz(bytes.len() as u64)) +} + +// reads hex text to bytes +pub fn read_hex_string( + raw: &mut cbor_event::de::Deserializer, +) -> Result<(Vec, StringEncoding), DeserializeError> { + let (text, text_enc) = raw.text_sz()?; + hex::decode(text) + .map(|bytes| (bytes, text_enc.into())) + .map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into()) +} + +// must include the tag since @custom_serialize at field-level overrides everything +pub fn write_tagged_uint_str<'se, W: std::io::Write>( + serializer: &'se mut cbor_event::se::Serializer, + uint: &u64, + tag_encoding: Option, + text_encoding: Option, +) -> cbor_event::Result<&'se mut cbor_event::se::Serializer> { + let uint_string = uint.to_string(); + let text_encoding = text_encoding + .map(|enc| crate::serialization::StringEncoding::Definite(enc)) + .unwrap_or(crate::serialization::StringEncoding::Canonical); + let uint_string_encoding = text_encoding.to_str_len_sz(uint_string.len() as u64); + serializer + .write_tag_sz(9, fit_sz(9, tag_encoding))? 
+ .write_text_sz(uint_string, uint_string_encoding) +} + +pub fn read_tagged_uint_str( + raw: &mut cbor_event::de::Deserializer, +) -> Result<(u64, Option, Option), DeserializeError> { + use std::str::FromStr; + let (tag, tag_encoding) = raw.tag_sz()?; + let (text, text_encoding) = raw.text_sz()?; + match text_encoding { + cbor_event::StringLenSz::Indefinite(_) => Err(DeserializeFailure::CBOR(cbor_event::Error::CustomError(format!("We only support definite encodings in order to use the uint one"))).into()), + cbor_event::StringLenSz::Len(text_encoding_sz) => u64::from_str(&text) + .map(|uint| (uint, Some(tag_encoding), Some(text_encoding_sz))) + .map_err(|e| DeserializeFailure::InvalidStructure(Box::new(e)).into()), + } +} \ No newline at end of file diff --git a/tests/external_rust_defs b/tests/external_rust_defs index a75af8e..6186056 100644 --- a/tests/external_rust_defs +++ b/tests/external_rust_defs @@ -80,4 +80,3 @@ impl serialization::Deserialize for ExternGeneric T::deserialize(raw).map(Self) } } - diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index dfac09d..3038478 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -146,3 +146,13 @@ enum_opt_embed_fields = [ ; @name eg 1, ? 
overlapping_inlined, #6.13(13) ] + +custom_bytes = bytes ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + +struct_with_custom_serialization = [ + custom_bytes, + field: bytes, ; @custom_serialize custom_serialize_bytes @custom_deserialize custom_deserialize_bytes + overridden: custom_bytes, ; @custom_serialize write_hex_string @custom_deserialize read_hex_string + tagged1: #6.9(custom_bytes), + tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str +] diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index db76713..fb44943 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -1046,4 +1046,32 @@ mod tests { } } } + + #[test] + fn custom_serialization() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_8_encodings = vec![ + StringLenSz::Len(Sz::One), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(3, Sz::Two), (5, Sz::One)]), + StringLenSz::Indefinite(vec![(0, Sz::Four), (4, Sz::Inline), (0, Sz::Inline), (4, Sz::Inline), (0, Sz::One)]), + ]; + for def_enc in &def_encodings { + let bytes_special_enc = StringLenSz::Indefinite(vec![(1, *def_enc); 4]); + for str_enc in &str_8_encodings { + let irregular_bytes = vec![ + arr_sz(5, *def_enc), + cbor_bytes_sz(vec![0xCA, 0xFE, 0xF0, 0x0D], bytes_special_enc.clone()), + cbor_bytes_sz(vec![0x03, 0x01, 0x04, 0x01], bytes_special_enc.clone()), + cbor_str_sz("baadd00d", str_enc.clone()), + cbor_tag(9), + cbor_bytes_sz(vec![0xDE, 0xAD, 0xBE, 0xEF], bytes_special_enc.clone()), + cbor_tag(9), + cbor_str_sz("10241024", StringLenSz::Len(*def_enc)) + ].into_iter().flatten().clone().collect::>(); + let from_bytes = StructWithCustomSerialization::from_cbor_bytes(&irregular_bytes).unwrap(); + assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes); + } + } + } }