Skip to content

Commit

Permalink
Deserialize for array groups with optional fields (#204)
Browse files Browse the repository at this point in the history
* Deserialize for array groups with optional fields

Fixes #203

Fixes #154 as this issue popped up while making sure optional fields
were very well supported.

* Optional array fields with preserve-encodings=true
  • Loading branch information
rooooooooob authored Aug 9, 2023
1 parent 719b9ab commit d6086cf
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 81 deletions.
263 changes: 184 additions & 79 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,11 @@ impl GenerationScope {
rust_struct.variant(),
RustStructType::Array { .. } | RustStructType::Table { .. }
);
// the is_referenced check is for things like Int which are included by default
// The is_referenced check is for things like Int which are included by default
// in order for the CDDL to parse but might not be used.
if !is_typedef && types.is_referenced(rust_ident) {
// However, we need to export other root types from the user's spec
if !is_typedef && (rust_ident.as_ref() != "Int" || types.is_referenced(rust_ident))
{
main_lines_by_file
.entry(types.scope(rust_ident).clone())
.or_default()
Expand Down Expand Up @@ -2042,6 +2044,10 @@ impl GenerationScope {
_ => unimplemented!(),
};
deser_code.throws = true;
// this block needs to evaluate to a Result even though it has no value
if !cli.preserve_encodings && before_after.expects_result {
deser_code.content.line("Ok(())");
}
}
SerializingRustType::Root(ConceptualRustType::Primitive(p)) => {
if config.optional_field {
Expand Down Expand Up @@ -4260,6 +4266,10 @@ fn generate_array_struct_serialization(
for field in record.fields.iter() {
let field_expr = format!("{}{}", opt_self, field.name);
if field.optional {
if field.rust_type.is_fixed_value() && !cli.preserve_encodings {
// we just want to skip this entirely if we aren't remembering encodings
continue;
}
let (optional_field_check, field_expr, expr_is_ref) = if let Some(default_value) =
&field.rust_type.default
{
Expand Down Expand Up @@ -4354,73 +4364,171 @@ fn generate_array_struct_deserialization(
let mut deser_code = DeserializationCode::default();
let mut deser_ctor_fields = vec![];
let mut encoding_struct_ctor_fields = vec![];
for field in record.fields.iter() {
if field.optional {
gen_scope.dont_generate_deserialize(
name,
format!(
"Array with optional field {}: {}",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
);
for (field_index, field) in record.fields.iter().enumerate() {
let (before, after) = if cli.preserve_encodings {
let var_names_str = encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
(
Cow::from(format!("let {var_names_str} = ")),
Cow::from("?;"),
)
} else {
(Cow::from(format!("let {var_names_str} = ")), Cow::from(";"))
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
(Cow::from(""), Cow::from("?;"))
} else {
(Cow::from(""), Cow::from(""))
}
} else if cli.annotate_fields {
(Cow::from(format!("let {} = ", field.name)), Cow::from("?;"))
} else {
if cli.preserve_encodings {
let var_names_str =
encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {var_names_str} = "), "?;")
.add_to_code(&mut deser_code);
(Cow::from(format!("let {} = ", field.name)), Cow::from(";"))
};
if field.optional {
// we can support optional fields, but only when they're immediately non-ambiguous
// i.e. when the next type (possibly skipping subsequent optional fields)
// is different from the current type.
// Supporting the general case 100% is extremely complicated without a combinatorial
// backtrack but for most sane real-world cases this wouldn't be necessary.
// Think purposefully written edge-cases with multiple optional fields, possibly nested
// in other structs, and with many of the same types.
// e.g. [ ? uint, uint, ? (uint, text), ? text]
let field_cbor_types = field.rust_type.cbor_types(types);
let mut possibly_last_field = true;
for i in (field_index + 1)..record.fields.len() {
if record.fields[i]
.rust_type
.cbor_types(types)
.iter()
.any(|ct| field_cbor_types.contains(ct))
{
gen_scope.dont_generate_deserialize(
name,
format!(
"Array struct with potentially-ambiguous optional field {}: {:?}",
field.name, field.rust_type,
),
);
}
if !record.fields[i].optional {
if i < record.fields.len() - 1 {
possibly_last_field = false;
}
break;
}
}
// we also need to be careful if we're possibly the last field in the CBOR
// buffer to avoid raw.cbor_type()? throwing an error for CBOR(NotEnough(0, 0))
let type_check_cond = if field_cbor_types.len() == 1 {
let type_str = cbor_type_code_str(field_cbor_types[0]);
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| cbor_event::Type::from(*byte) == cbor_event::Type::Special && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)".to_owned()
} else {
format!("if raw.cbor_type().map(|ty| ty == {type_str}).unwrap_or(false)")
}
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(
&format!("let {var_names_str} = "),
";",
false,
),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
format!("if raw.cbor_type()? == {type_str}")
}
} else {
let types_str = field_cbor_types
.iter()
.map(|ty| cbor_type_code_str(*ty))
.collect::<Vec<_>>()
.join(", ");
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
format!(
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| vec![{types_str}].contains(&cbor_event::Type::from(*byte)) && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)",
)
.add_to_code(&mut deser_code);
} else {
format!("if raw.cbor_type().map(|ty| vec![{types_str}].contains(&ty)).unwrap_or(false)")
}
} else {
format!("if vec![{types_str}].contains(&raw.cbor_type()?)")
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
let mut err_deser = gen_scope.generate_deserialize(
};
let type_check_block = Block::new(format!("{before}{type_check_cond}"));
let mut type_check_else = Block::new("else");
if cli.annotate_fields {
let enc_fields = if cli.preserve_encodings {
let resolved_rust_type = field.rust_type.clone().resolve_aliases();
assert!(
!resolved_rust_type.is_fixed_value(),
"https://github.com/dcSpark/cddl-codegen/issues/205"
);
encoding_fields(types, &field.name, &resolved_rust_type, false, cli)
} else {
vec![]
};
let (some_map, defaults) = if !enc_fields.is_empty() {
let enc_names_str = enc_fields
.iter()
.map(|enc| enc.field_name.clone())
.collect::<Vec<String>>()
.join(", ");
(
Cow::from(format!(
"|({}, {})| (Some({}), {})",
field.name, enc_names_str, field.name, enc_names_str
)),
Cow::from(format!(
"(None, {})",
enc_fields
.iter()
.map(|enc| enc.default_expr.to_owned())
.collect::<Vec<String>>()
.join(", ")
)),
)
} else {
(Cow::from("Some"), Cow::from("None"))
};
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
);
// this block needs to evaluate to a Result even though it has no value
err_deser.content.line("Ok(())");
err_deser
.annotate(&field.name, "", "?;")
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
} else if cli.annotate_fields {
)
.annotate(&field.name, "", &format!(".map({some_map})"))
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line(format!("Ok({defaults})"));
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("Some(", ")", false),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
)
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line("None");
}
type_check_else.after(after);
deser_code.content.push_block(type_check_else);
} else {
// mandatory fields
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
Expand All @@ -4429,22 +4537,22 @@ fn generate_array_struct_deserialization(
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {} = ", field.name), "?;")
.annotate(&field.name, before.as_ref(), after.as_ref())
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(&format!("let {} = ", field.name), ";", false),
DeserializeBeforeAfter::new(before.as_ref(), after.as_ref(), false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if cli.preserve_encodings {
Expand All @@ -4459,18 +4567,15 @@ fn generate_array_struct_deserialization(
encoding_vars_output.push(("tag_encoding".to_owned(), "Some(tag_encoding)".to_owned()));
}
for field in record.fields.iter() {
// we don't support deserialization for optional fields so don't even bother
if !field.optional {
for field_enc in encoding_fields(
types,
&field.name,
&field.rust_type.clone().resolve_aliases(),
true,
cli,
) {
encoding_vars_output
.push((field_enc.field_name.clone(), field_enc.field_name.clone()));
}
for field_enc in encoding_fields(
types,
&field.name,
&field.rust_type.clone().resolve_aliases(),
true,
cli,
) {
encoding_vars_output
.push((field_enc.field_name.clone(), field_enc.field_name.clone()));
}
}
}
Expand Down Expand Up @@ -4607,7 +4712,7 @@ fn codegen_struct(
gen_scope.dont_generate_deserialize(
name,
format!(
"field {}: {} couldn't generate serialize",
"field {}: {} couldn't generate deserialize",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
Expand Down
4 changes: 4 additions & 0 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,10 @@ impl RustRecord {
let mut conditional_field_expr = String::new();
for field in &self.fields {
if field.optional {
if !cli.preserve_encodings && field.rust_type.is_fixed_value() {
// we don't create fields for fixed values when preserve-encodings=false
continue;
}
if !conditional_field_expr.is_empty() {
conditional_field_expr.push_str(" + ");
}
Expand Down
13 changes: 12 additions & 1 deletion tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,15 @@ overlapping0 = [0]
overlapping1 = [0, uint]
overlapping2 = [0, uint, text]

overlapping = overlapping0 / overlapping1 / overlapping2
overlapping = overlapping0 / overlapping1 / overlapping2

; Test fixture: an array group that mixes optional (`?`) and mandatory entries.
; x, y and z are fixed float values; c and y are the only mandatory entries.
; `non_overlapping_type_choice_some` is a rule defined elsewhere in this file.
array_opt_fields = [
? x: 1.010101
? a: uint,
? b: text,
c: nint,
? d: text,
y: 3.14159265
? e: non_overlapping_type_choice_some
? z: 2.71828
]
Loading

0 comments on commit d6086cf

Please sign in to comment.