Skip to content

Commit

Permalink
Deserialize for array groups with optional fields
Browse files Browse the repository at this point in the history
Fixes #203

Fixes #154 as this issue popped up while making sure optional fields
were very well supported.
  • Loading branch information
rooooooooob committed Jul 27, 2023
1 parent 0469793 commit 623218c
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 80 deletions.
232 changes: 153 additions & 79 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,11 @@ impl GenerationScope {
rust_struct.variant(),
RustStructType::Array { .. } | RustStructType::Table { .. }
);
// the is_referenced check is for things like Int which are included by default
// The is_referenced check is for things like Int which are included by default
// in order for the CDDL to parse but might not be used.
if !is_typedef && types.is_referenced(rust_ident) {
// However, we need to export other root types from the user's spec
if !is_typedef && (rust_ident.as_ref() != "Int" || types.is_referenced(rust_ident))
{
main_lines_by_file
.entry(types.scope(rust_ident).clone())
.or_default()
Expand Down Expand Up @@ -2042,6 +2044,10 @@ impl GenerationScope {
_ => unimplemented!(),
};
deser_code.throws = true;
// this block needs to evaluate to a Result even though it has no value
if !cli.preserve_encodings && before_after.expects_result {
deser_code.content.line("Ok(())");
}
}
SerializingRustType::Root(ConceptualRustType::Primitive(p)) => {
if config.optional_field {
Expand Down Expand Up @@ -4260,6 +4266,10 @@ fn generate_array_struct_serialization(
for field in record.fields.iter() {
let field_expr = format!("{}{}", opt_self, field.name);
if field.optional {
if field.rust_type.is_fixed_value() && !cli.preserve_encodings {
// we just want to skip this entirely if we aren't remembering enecodings
continue;
}
let (optional_field_check, field_expr, expr_is_ref) = if let Some(default_value) =
&field.rust_type.default
{
Expand Down Expand Up @@ -4354,73 +4364,139 @@ fn generate_array_struct_deserialization(
let mut deser_code = DeserializationCode::default();
let mut deser_ctor_fields = vec![];
let mut encoding_struct_ctor_fields = vec![];
for field in record.fields.iter() {
if field.optional {
gen_scope.dont_generate_deserialize(
name,
format!(
"Array with optional field {}: {}",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
);
for (field_index, field) in record.fields.iter().enumerate() {
let (before, after) = if cli.preserve_encodings {
let var_names_str = encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
(
Cow::from(format!("let {var_names_str} = ")),
Cow::from("?;"),
)
} else {
(Cow::from(format!("let {var_names_str} = ")), Cow::from(";"))
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
(Cow::from(""), Cow::from("?;"))
} else {
(Cow::from(""), Cow::from(""))
}
} else if cli.annotate_fields {
(Cow::from(format!("let {} = ", field.name)), Cow::from("?;"))
} else {
if cli.preserve_encodings {
let var_names_str =
encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {var_names_str} = "), "?;")
.add_to_code(&mut deser_code);
(Cow::from(format!("let {} = ", field.name)), Cow::from(";"))
};
if field.optional {
// we can support optional fields, but only when they're immediately non-ambiguous
// i.e. when the next type (possibly skipping subsequent optional fields)
// is different from the current type.
// Supporting the general case 100% is extremely complicated without a combinatorial
// backtrack but for most sane real-world cases this wouldn't be necessary.
// Think purposefully written edge-cases with multiple optional fields, possibly nested
// in other structs, and with many of the same types.
// e.g. [ ? uint, uint, ? (uint, text), ? text]
let field_cbor_types = field.rust_type.cbor_types(types);
let mut possibly_last_field = true;
for i in (field_index + 1)..record.fields.len() {
if record.fields[i]
.rust_type
.cbor_types(types)
.iter()
.any(|ct| field_cbor_types.contains(ct))
{
gen_scope.dont_generate_deserialize(
name,
format!(
"Array struct with potentially-ambiguous optional field {}: {}",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
);
}
if !record.fields[i].optional {
if i < record.fields.len() - 1 {
possibly_last_field = false;
}
break;
}
}
// we also need to be careful if we're possibly the last field in the CBOR
// buffer to avoid raw.cbor_type()? throwing an error for CBOR(NotEnough(0, 0))
let type_check_cond = if field_cbor_types.len() == 1 {
let type_str = cbor_type_code_str(field_cbor_types[0]);
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| cbor_event::Type::from(*byte) == cbor_event::Type::Special && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)".to_owned()
} else {
format!("if raw.cbor_type().map(|ty| ty == {type_str}).unwrap_or(false)")
}
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(
&format!("let {var_names_str} = "),
";",
false,
),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
format!("if raw.cbor_type()? == {type_str}")
}
} else {
let types_str = field_cbor_types
.iter()
.map(|ty| cbor_type_code_str(*ty))
.collect::<Vec<_>>()
.join(", ");
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
format!(
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| vec![{types_str}].contains(&cbor_event::Type::from(*byte)) && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)",
)
.add_to_code(&mut deser_code);
} else {
format!("if raw.cbor_type().map(|ty| vec![{types_str}].contains(&ty)).unwrap_or(false)")
}
} else {
format!("if vec![{types_str}].contains(&raw.cbor_type()?)")
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
let mut err_deser = gen_scope.generate_deserialize(
};
let type_check_block = Block::new(format!("{before}{type_check_cond}"));
let mut type_check_else = Block::new("else");
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
);
// this block needs to evaluate to a Result even though it has no value
err_deser.content.line("Ok(())");
err_deser
.annotate(&field.name, "", "?;")
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
} else if cli.annotate_fields {
)
.annotate(&field.name, "", ".map(Some)")
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line("Ok(None)");
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("Some(", ")", false),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
)
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line("None");
}
type_check_else.after(after);
deser_code.content.push_block(type_check_else);
} else {
// mandatory fields
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
Expand All @@ -4429,22 +4505,22 @@ fn generate_array_struct_deserialization(
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {} = ", field.name), "?;")
.annotate(&field.name, before.as_ref(), after.as_ref())
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(&format!("let {} = ", field.name), ";", false),
DeserializeBeforeAfter::new(before.as_ref(), after.as_ref(), false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if cli.preserve_encodings {
Expand Down Expand Up @@ -4607,7 +4683,7 @@ fn codegen_struct(
gen_scope.dont_generate_deserialize(
name,
format!(
"field {}: {} couldn't generate serialize",
"field {}: {} couldn't generate deserialize",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
Expand Down Expand Up @@ -6302,19 +6378,17 @@ fn generate_enum(
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if names_without_outer.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
if names_without_outer.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
variant_deser_code
}
Expand Down
4 changes: 4 additions & 0 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,10 @@ impl RustRecord {
let mut conditional_field_expr = String::new();
for field in &self.fields {
if field.optional {
if !cli.preserve_encodings && field.rust_type.is_fixed_value() {
// we don't create fields for fixed values when preserve-encodings=false
continue;
}
if !conditional_field_expr.is_empty() {
conditional_field_expr.push_str(" + ");
}
Expand Down
13 changes: 12 additions & 1 deletion tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,15 @@ overlapping0 = [0]
overlapping1 = [0, uint]
overlapping2 = [0, uint, text]

overlapping = overlapping0 / overlapping1 / overlapping2
overlapping = overlapping0 / overlapping1 / overlapping2

array_opt_fields = [
? x: 1.010101
? a: uint,
? b: text,
c: nint,
? d: text,
y: 3.14159265
? e: non_overlapping_type_choice_some
? z: 2.71828
]
55 changes: 55 additions & 0 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,59 @@ mod tests {
deser_test(&NonOverlappingTypeChoiceSome::N64(10000));
deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into()));
}

#[test]
fn array_opt_fields() {
let mut foo = ArrayOptFields::new(10);
for e in [None, Some(NonOverlappingTypeChoiceSome::U64(5)), Some(NonOverlappingTypeChoiceSome::N64(4)), Some(NonOverlappingTypeChoiceSome::Text("five".to_owned()))] {
for a in [false, true] {
for b in [false, true] {
for d in [false, true] {
// round-trip on non-constants
foo.a = if a { Some(0) } else { None };
foo.b = if b { Some("hello, world".to_owned()) } else { None };
foo.d = if d { Some("cddl-codegen".to_owned()) } else { None };
foo.e = e.clone();
deser_test(&foo);
// deser for constants too
for x in [false, true] {
for y in [false, true] {
for z in [false, true] {
let mut components = vec![vec![ARR_INDEF]];
let bytes = vec![
vec![ARR_INDEF]
];
if x {
components.push(cbor_float(1.010101));
}
if a {
components.push(cbor_int(0, cbor_event::Sz::One));
}
if b {
components.push(cbor_string("hello, world"));
}
// c
components.push(cbor_int(-10, cbor_event::Sz::One));
if d {
components.push(cbor_string("cddl-codegen"));
}
// y
components.push(cbor_float(3.14159265));
if let Some(e) = &e {
components.push(e.to_cbor_bytes());
}
if z {
components.push(cbor_float(2.71828));
}
components.push(vec![BREAK]);
let bytes = components.into_iter().flatten().clone().collect::<Vec<u8>>();
let _ = ArrayOptFields::from_cbor_bytes(&bytes).unwrap();
}
}
}
}
}
}
}
}
}
Loading

0 comments on commit 623218c

Please sign in to comment.