From d6086cfeecf4f6b055440ae10e0de0b3b251434d Mon Sep 17 00:00:00 2001 From: rooooooooob Date: Wed, 9 Aug 2023 12:33:18 -0400 Subject: [PATCH] Deserialize for array groups with optional fields (#204) * Deserialize for array groups with optional fields Fixes #203 Fixes #154 as this issue popped up while making sure optional fields were very well supported. * Optional array fields with preserve-encodings=true --- src/generation.rs | 263 +++++++++++++++++++--------- src/intermediate.rs | 4 + tests/core/input.cddl | 13 +- tests/core/tests.rs | 55 ++++++ tests/deser_test | 6 + tests/preserve-encodings/input.cddl | 14 +- tests/preserve-encodings/tests.rs | 71 ++++++++ 7 files changed, 345 insertions(+), 81 deletions(-) diff --git a/src/generation.rs b/src/generation.rs index 3e47355..50959a4 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -700,9 +700,11 @@ impl GenerationScope { rust_struct.variant(), RustStructType::Array { .. } | RustStructType::Table { .. } ); - // the is_referenced check is for things like Int which are included by default + // The is_referenced check is for things like Int which are included by default // in order for the CDDL to parse but might not be used. - if !is_typedef && types.is_referenced(rust_ident) { + // However, we need to export other root types from the user's spec + if !is_typedef && (rust_ident.as_ref() != "Int" || types.is_referenced(rust_ident)) + { main_lines_by_file .entry(types.scope(rust_ident).clone()) .or_default() @@ -2042,6 +2044,10 @@ impl GenerationScope { _ => unimplemented!(), }; deser_code.throws = true; + // this block needs to evaluate to a Result even though it has no value + if !cli.preserve_encodings && before_after.expects_result { + deser_code.content.line("Ok(())"); + } } SerializingRustType::Root(ConceptualRustType::Primitive(p)) => { if config.optional_field { @@ -4260,6 +4266,10 @@ fn generate_array_struct_serialization( for field in record.fields.iter() { let field_expr = format!("{}{}", opt_self, field.name); if field.optional { + if field.rust_type.is_fixed_value() && !cli.preserve_encodings { + // we just want to skip this entirely if we aren't remembering enecodings + continue; + } let (optional_field_check, field_expr, expr_is_ref) = if let Some(default_value) = &field.rust_type.default { @@ -4354,73 +4364,171 @@ fn generate_array_struct_deserialization( let mut deser_code = DeserializationCode::default(); let mut deser_ctor_fields = vec![]; let mut encoding_struct_ctor_fields = vec![]; - for field in record.fields.iter() { - if field.optional { - gen_scope.dont_generate_deserialize( - name, - format!( - "Array with optional field {}: {}", - field.name, - field.rust_type.for_rust_member(types, false, cli) - ), - ); + for (field_index, field) in record.fields.iter().enumerate() { + let (before, after) = if cli.preserve_encodings { + let var_names_str = encoding_var_names_str(types, &field.name, &field.rust_type, cli); + if cli.annotate_fields { + ( + Cow::from(format!("let {var_names_str} = ")), + Cow::from("?;"), + ) + } else { + (Cow::from(format!("let {var_names_str} = ")), Cow::from(";")) + } + } else if field.rust_type.is_fixed_value() { + // don't set anything, only verify data + if cli.annotate_fields { + (Cow::from(""), Cow::from("?;")) + } else { + (Cow::from(""), Cow::from("")) + } + } else if cli.annotate_fields { + (Cow::from(format!("let {} = ", field.name)), Cow::from("?;")) } else { - if cli.preserve_encodings { - let var_names_str = - encoding_var_names_str(types, &field.name, &field.rust_type, cli); - if cli.annotate_fields { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), - cli, - ) - .annotate(&field.name, &format!("let {var_names_str} = "), "?;") - .add_to_code(&mut deser_code); + (Cow::from(format!("let {} = ", field.name)), Cow::from(";")) + }; + if field.optional { + // we can support optional fields, but only when they're immediately non-ambiguous + // i.e. when the next type (possibly skipping subsequent optional fields) + // is different from the current type. + // Supporting the general case 100% is extremely complicated without a combinatorial + // backtrack but for most sane real-world cases this wouldn't be necessary. + // Think purposefully written edge-cases with multiple optional fields, possibly nested + // in other structs, and with many of the same types. + // e.g. [ ? uint, uint, ? (uint, text), ? text] + let field_cbor_types = field.rust_type.cbor_types(types); + let mut possibly_last_field = true; + for i in (field_index + 1)..record.fields.len() { + if record.fields[i] + .rust_type + .cbor_types(types) + .iter() + .any(|ct| field_cbor_types.contains(ct)) + { + gen_scope.dont_generate_deserialize( + name, + format!( + "Array struct with potentially-ambiguous optional field {}: {:?}", + field.name, field.rust_type, + ), + ); + } + if !record.fields[i].optional { + if i < record.fields.len() - 1 { + possibly_last_field = false; + } + break; + } + } + // we also need to be careful if we're possibly the last field in the CBOR + // buffer to avoid raw.cbor_type()? throwing an error for CBOR(NotEnough(0, 0)) + let type_check_cond = if field_cbor_types.len() == 1 { + let type_str = cbor_type_code_str(field_cbor_types[0]); + if possibly_last_field { + // We also need to be careful if the last one is a non-Break special + // and the array is encoded using indefinite encoding. + // There's no nice way to access this as Deserializer::special_break() consumes + // the byte so we'll just inline this ugly code instead + if field_cbor_types.contains(&cbor_event::Type::Special) { + "if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| cbor_event::Type::from(*byte) == cbor_event::Type::Special && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)".to_owned() + } else { + format!("if raw.cbor_type().map(|ty| ty == {type_str}).unwrap_or(false)") + } } else { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new( - &format!("let {var_names_str} = "), - ";", - false, - ), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), - cli, + format!("if raw.cbor_type()? == {type_str}") + } + } else { + let types_str = field_cbor_types + .iter() + .map(|ty| cbor_type_code_str(*ty)) + .collect::>() + .join(", "); + if possibly_last_field { + // We also need to be careful if the last one is a non-Break special + // and the array is encoded using indefinite encoding. + // There's no nice way to access this as Deserializer::special_break() consumes + // the byte so we'll just inline this ugly code instead + if field_cbor_types.contains(&cbor_event::Type::Special) { + format!( + "if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| vec![{types_str}].contains(&cbor_event::Type::from(*byte)) && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)", ) - .add_to_code(&mut deser_code); + } else { + format!("if raw.cbor_type().map(|ty| vec![{types_str}].contains(&ty)).unwrap_or(false)") + } + } else { + format!("if vec![{types_str}].contains(&raw.cbor_type()?)") } - } else if field.rust_type.is_fixed_value() { - // don't set anything, only verify data - if cli.annotate_fields { - let mut err_deser = gen_scope.generate_deserialize( + }; + let type_check_block = Block::new(format!("{before}{type_check_cond}")); + let mut type_check_else = Block::new("else"); + if cli.annotate_fields { + let enc_fields = if cli.preserve_encodings { + let resolved_rust_type = field.rust_type.clone().resolve_aliases(); + assert!( + !resolved_rust_type.is_fixed_value(), + "https://github.com/dcSpark/cddl-codegen/issues/205" + ); + encoding_fields(types, &field.name, &resolved_rust_type, false, cli) + } else { + vec![] + }; + let (some_map, defaults) = if !enc_fields.is_empty() { + let enc_names_str = enc_fields + .iter() + .map(|enc| enc.field_name.clone()) + .collect::>() + .join(", "); + ( + Cow::from(format!( + "|({}, {})| (Some({}), {})", + field.name, enc_names_str, field.name, enc_names_str + )), + Cow::from(format!( + "(None, {})", + enc_fields + .iter() + .map(|enc| enc.default_expr.to_owned()) + .collect::>() + .join(", ") + )), + ) + } else { + (Cow::from("Some"), Cow::from("None")) + }; + gen_scope + .generate_deserialize( types, (&field.rust_type).into(), DeserializeBeforeAfter::new("", "", true), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), + DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(true), cli, - ); - // this block needs to evaluate to a Result even though it has no value - err_deser.content.line("Ok(())"); - err_deser - .annotate(&field.name, "", "?;") - .add_to_code(&mut deser_code); - } else { - gen_scope - .generate_deserialize( - types, - (&field.rust_type).into(), - DeserializeBeforeAfter::new("", "", false), - DeserializeConfig::new(&field.name).in_embedded(in_embedded), - cli, - ) - .add_to_code(&mut deser_code); - } - } else if cli.annotate_fields { + ) + .annotate(&field.name, "", &format!(".map({some_map})")) + .wrap_in_block(type_check_block) + .add_to_code(&mut deser_code); + type_check_else.line(format!("Ok({defaults})")); + } else { + gen_scope + .generate_deserialize( + types, + (&field.rust_type).into(), + DeserializeBeforeAfter::new("Some(", ")", false), + DeserializeConfig::new(&field.name) + .in_embedded(in_embedded) + .optional_field(true), + cli, + ) + .wrap_in_block(type_check_block) + .add_to_code(&mut deser_code); + type_check_else.line("None"); + } + type_check_else.after(after); + deser_code.content.push_block(type_check_else); + } else { + // mandatory fields + if cli.annotate_fields { gen_scope .generate_deserialize( types, @@ -4429,22 +4537,22 @@ fn generate_array_struct_deserialization( DeserializeConfig::new(&field.name).in_embedded(in_embedded), cli, ) - .annotate(&field.name, &format!("let {} = ", field.name), "?;") + .annotate(&field.name, before.as_ref(), after.as_ref()) .add_to_code(&mut deser_code); } else { gen_scope .generate_deserialize( types, (&field.rust_type).into(), - DeserializeBeforeAfter::new(&format!("let {} = ", field.name), ";", false), + DeserializeBeforeAfter::new(before.as_ref(), after.as_ref(), false), DeserializeConfig::new(&field.name).in_embedded(in_embedded), cli, ) .add_to_code(&mut deser_code); } - if !field.rust_type.is_fixed_value() { - deser_ctor_fields.push((field.name.clone(), field.name.clone())); - } + } + if !field.rust_type.is_fixed_value() { + deser_ctor_fields.push((field.name.clone(), field.name.clone())); } } if cli.preserve_encodings { @@ -4459,18 +4567,15 @@ fn generate_array_struct_deserialization( encoding_vars_output.push(("tag_encoding".to_owned(), "Some(tag_encoding)".to_owned())); } for field in record.fields.iter() { - // we don't support deserialization for optional fields so don't even bother - if !field.optional { - for field_enc in encoding_fields( - types, - &field.name, - &field.rust_type.clone().resolve_aliases(), - true, - cli, - ) { - encoding_vars_output - .push((field_enc.field_name.clone(), field_enc.field_name.clone())); - } + for field_enc in encoding_fields( + types, + &field.name, + &field.rust_type.clone().resolve_aliases(), + true, + cli, + ) { + encoding_vars_output + .push((field_enc.field_name.clone(), field_enc.field_name.clone())); } } } @@ -4607,7 +4712,7 @@ fn codegen_struct( gen_scope.dont_generate_deserialize( name, format!( - "field {}: {} couldn't generate serialize", + "field {}: {} couldn't generate deserialize", field.name, field.rust_type.for_rust_member(types, false, cli) ), diff --git a/src/intermediate.rs b/src/intermediate.rs index 14a9eaa..f360d69 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -2372,6 +2372,10 @@ impl RustRecord { let mut conditional_field_expr = String::new(); for field in &self.fields { if field.optional { + if !cli.preserve_encodings && field.rust_type.is_fixed_value() { + // we don't create fields for fixed values when preserve-encodings=false + continue; + } if !conditional_field_expr.is_empty() { conditional_field_expr.push_str(" + "); } diff --git a/tests/core/input.cddl b/tests/core/input.cddl index a946f8e..f10af52 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -134,4 +134,15 @@ overlapping0 = [0] overlapping1 = [0, uint] overlapping2 = [0, uint, text] -overlapping = overlapping0 / overlapping1 / overlapping2 \ No newline at end of file +overlapping = overlapping0 / overlapping1 / overlapping2 + +array_opt_fields = [ + ? x: 1.010101 + ? a: uint, + ? b: text, + c: nint, + ? d: text, + y: 3.14159265 + ? e: non_overlapping_type_choice_some + ? z: 2.71828 +] \ No newline at end of file diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 15b2f40..8e4d61b 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -239,4 +239,59 @@ mod tests { deser_test(&NonOverlappingTypeChoiceSome::N64(10000)); deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into())); } + + #[test] + fn array_opt_fields() { + let mut foo = ArrayOptFields::new(10); + for e in [None, Some(NonOverlappingTypeChoiceSome::U64(5)), Some(NonOverlappingTypeChoiceSome::N64(4)), Some(NonOverlappingTypeChoiceSome::Text("five".to_owned()))] { + for a in [false, true] { + for b in [false, true] { + for d in [false, true] { + // round-trip on non-constants + foo.a = if a { Some(0) } else { None }; + foo.b = if b { Some("hello, world".to_owned()) } else { None }; + foo.d = if d { Some("cddl-codegen".to_owned()) } else { None }; + foo.e = e.clone(); + deser_test(&foo); + // deser for constants too + for x in [false, true] { + for y in [false, true] { + for z in [false, true] { + let mut components = vec![vec![ARR_INDEF]]; + let bytes = vec![ + vec![ARR_INDEF] + ]; + if x { + components.push(cbor_float(1.010101)); + } + if a { + components.push(cbor_int(0, cbor_event::Sz::One)); + } + if b { + components.push(cbor_string("hello, world")); + } + // c + components.push(cbor_int(-10, cbor_event::Sz::One)); + if d { + components.push(cbor_string("cddl-codegen")); + } + // y + components.push(cbor_float(3.14159265)); + if let Some(e) = &e { + components.push(e.to_cbor_bytes()); + } + if z { + components.push(cbor_float(2.71828)); + } + components.push(vec![BREAK]); + let bytes = components.into_iter().flatten().clone().collect::>(); + let _ = ArrayOptFields::from_cbor_bytes(&bytes).unwrap(); + } + } + } + } + } + } + } + } } diff --git a/tests/deser_test b/tests/deser_test index 43d3298..9ebc23d 100644 --- a/tests/deser_test +++ b/tests/deser_test @@ -67,6 +67,12 @@ fn cbor_bytes_sz(bytes: Vec, sz: cbor_event::StringLenSz) -> Vec { buf.finalize() } +fn cbor_float(f: f64) -> Vec { + let mut buf = cbor_event::se::Serializer::new_vec(); + buf.write_special(cbor_event::Special::Float(f)).unwrap(); + buf.finalize() +} + fn print_cbor_types(obj_name: &str, vec: &Vec) { use cbor_event::Type; let mut raw = cbor_event::de::Deserializer::from(std::io::Cursor::new(vec)); diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index b4cbb20..354b125 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -76,4 +76,16 @@ default_uint = uint .default 1337 map_with_defaults = { ? 1 : default_uint ? 2 : text .default "two" -} \ No newline at end of file +} + +; TODO: preserve-encodings remembering optional fixed values. Issue: https://github.com/dcSpark/cddl-codegen/issues/205 +array_opt_fields = [ +; ? x: null, + ? a: uint, + ? b: text, + c: nint, + ? d: text, + y: #6.10(1), + ? e: non_overlapping_type_choice_some +; ? z: null, +] \ No newline at end of file diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index 7f5016d..056db1b 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -644,4 +644,75 @@ mod tests { } } } + + #[test] + fn array_opt_fields() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_12_encodings = vec![ + StringLenSz::Len(Sz::One), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(5, Sz::Two), (7, Sz::One)]), + StringLenSz::Indefinite(vec![(3, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]), + ]; + for str_enc in &str_12_encodings { + for def_enc in &def_encodings { + let e_values = [ + None, + Some(NonOverlappingTypeChoiceSome::U64 { + uint: 5, + uint_encoding: Some(*def_enc), + }), + Some(NonOverlappingTypeChoiceSome::N64 { + n64: 4, + n64_encoding: Some(*def_enc), + }), + Some(NonOverlappingTypeChoiceSome::Text { + text: "twelve chars".to_owned(), + text_encoding: str_enc.clone().into(), + }), + ]; + for e in &e_values { + for a in [false, true] { + for b in [false, true] { + for d in [false, true] { + // TODO: preserve-encodings remembering optional fixed values. Issue: https://github.com/dcSpark/cddl-codegen/issues/205 + // for x in [false, true] { + // for z in [false, true] { + let mut components: Vec> = vec![vec![ARR_INDEF]]; + // if x { + // components.push(vec![0xf5]); + // } + if a { + components.push(cbor_int(0, *def_enc)); + } + if b { + components.push(cbor_str_sz("hello, world", str_enc.clone())); + } + // c + components.push(cbor_int(-10, *def_enc)); + if d { + components.push(cbor_str_sz("cddl-codegen", str_enc.clone())); + } + // y + components.push(cbor_tag_sz(10, *def_enc)); + components.push(cbor_int(1, *def_enc)); + if let Some(e) = &e { + components.push(e.to_cbor_bytes()); + } + // if z { + // //components.push(vec![NULL]); + // } + components.push(vec![BREAK]); + let irregular_bytes = components.into_iter().flatten().clone().collect::>(); + let irregular = ArrayOptFields::from_cbor_bytes(&irregular_bytes).unwrap(); + assert_eq!(irregular_bytes, irregular.to_cbor_bytes()); + // } + // } + } + } + } + } + } + } + } }