Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deserialize for array groups with optional fields #204

Merged
merged 2 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 184 additions & 79 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,11 @@ impl GenerationScope {
rust_struct.variant(),
RustStructType::Array { .. } | RustStructType::Table { .. }
);
// the is_referenced check is for things like Int which are included by default
// The is_referenced check is for things like Int which are included by default
// in order for the CDDL to parse but might not be used.
if !is_typedef && types.is_referenced(rust_ident) {
// However, we need to export other root types from the user's spec
if !is_typedef && (rust_ident.as_ref() != "Int" || types.is_referenced(rust_ident))
{
main_lines_by_file
.entry(types.scope(rust_ident).clone())
.or_default()
Expand Down Expand Up @@ -2042,6 +2044,10 @@ impl GenerationScope {
_ => unimplemented!(),
};
deser_code.throws = true;
// this block needs to evaluate to a Result even though it has no value
if !cli.preserve_encodings && before_after.expects_result {
deser_code.content.line("Ok(())");
}
}
SerializingRustType::Root(ConceptualRustType::Primitive(p)) => {
if config.optional_field {
Expand Down Expand Up @@ -4260,6 +4266,10 @@ fn generate_array_struct_serialization(
for field in record.fields.iter() {
let field_expr = format!("{}{}", opt_self, field.name);
if field.optional {
if field.rust_type.is_fixed_value() && !cli.preserve_encodings {
// we just want to skip this entirely if we aren't remembering enecodings
continue;
}
let (optional_field_check, field_expr, expr_is_ref) = if let Some(default_value) =
&field.rust_type.default
{
Expand Down Expand Up @@ -4354,73 +4364,171 @@ fn generate_array_struct_deserialization(
let mut deser_code = DeserializationCode::default();
let mut deser_ctor_fields = vec![];
let mut encoding_struct_ctor_fields = vec![];
for field in record.fields.iter() {
if field.optional {
gen_scope.dont_generate_deserialize(
name,
format!(
"Array with optional field {}: {}",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
);
for (field_index, field) in record.fields.iter().enumerate() {
let (before, after) = if cli.preserve_encodings {
let var_names_str = encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
(
Cow::from(format!("let {var_names_str} = ")),
Cow::from("?;"),
)
} else {
(Cow::from(format!("let {var_names_str} = ")), Cow::from(";"))
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
(Cow::from(""), Cow::from("?;"))
} else {
(Cow::from(""), Cow::from(""))
}
} else if cli.annotate_fields {
(Cow::from(format!("let {} = ", field.name)), Cow::from("?;"))
} else {
if cli.preserve_encodings {
let var_names_str =
encoding_var_names_str(types, &field.name, &field.rust_type, cli);
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {var_names_str} = "), "?;")
.add_to_code(&mut deser_code);
(Cow::from(format!("let {} = ", field.name)), Cow::from(";"))
};
if field.optional {
// we can support optional fields, but only when they're immediately non-ambiguous
// i.e. when the next type (possibly skipping subsequent optional fields)
// is different from the current type.
// Supporting the general case 100% is extremely complicated without a combinatorial
// backtrack but for most sane real-world cases this wouldn't be necessary.
// Think purposefully written edge-cases with multiple optional fields, possibly nested
// in other structs, and with many of the same types.
// e.g. [ ? uint, uint, ? (uint, text), ? text]
let field_cbor_types = field.rust_type.cbor_types(types);
let mut possibly_last_field = true;
for i in (field_index + 1)..record.fields.len() {
if record.fields[i]
.rust_type
.cbor_types(types)
.iter()
.any(|ct| field_cbor_types.contains(ct))
{
gen_scope.dont_generate_deserialize(
name,
format!(
"Array struct with potentially-ambiguous optional field {}: {:?}",
field.name, field.rust_type,
),
);
}
if !record.fields[i].optional {
if i < record.fields.len() - 1 {
possibly_last_field = false;
}
break;
}
}
// we also need to be careful if we're possibly the last field in the CBOR
// buffer to avoid raw.cbor_type()? throwing an error for CBOR(NotEnough(0, 0))
let type_check_cond = if field_cbor_types.len() == 1 {
let type_str = cbor_type_code_str(field_cbor_types[0]);
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| cbor_event::Type::from(*byte) == cbor_event::Type::Special && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)".to_owned()
} else {
format!("if raw.cbor_type().map(|ty| ty == {type_str}).unwrap_or(false)")
}
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(
&format!("let {var_names_str} = "),
";",
false,
),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
format!("if raw.cbor_type()? == {type_str}")
}
} else {
let types_str = field_cbor_types
.iter()
.map(|ty| cbor_type_code_str(*ty))
.collect::<Vec<_>>()
.join(", ");
if possibly_last_field {
// We also need to be careful if the last one is a non-Break special
// and the array is encoded using indefinite encoding.
// There's no nice way to access this as Deserializer::special_break() consumes
// the byte so we'll just inline this ugly code instead
if field_cbor_types.contains(&cbor_event::Type::Special) {
format!(
"if raw.as_mut_ref().fill_buf().ok().and_then(|buf| buf.get(0)).map(|byte: &u8| vec![{types_str}].contains(&cbor_event::Type::from(*byte)) && (*byte & 0b0001_1111) != 0x1f).unwrap_or(false)",
)
.add_to_code(&mut deser_code);
} else {
format!("if raw.cbor_type().map(|ty| vec![{types_str}].contains(&ty)).unwrap_or(false)")
}
} else {
format!("if vec![{types_str}].contains(&raw.cbor_type()?)")
}
} else if field.rust_type.is_fixed_value() {
// don't set anything, only verify data
if cli.annotate_fields {
let mut err_deser = gen_scope.generate_deserialize(
};
let type_check_block = Block::new(format!("{before}{type_check_cond}"));
let mut type_check_else = Block::new("else");
if cli.annotate_fields {
let enc_fields = if cli.preserve_encodings {
let resolved_rust_type = field.rust_type.clone().resolve_aliases();
assert!(
!resolved_rust_type.is_fixed_value(),
"https://github.com/dcSpark/cddl-codegen/issues/205"
);
encoding_fields(types, &field.name, &resolved_rust_type, false, cli)
} else {
vec![]
};
let (some_map, defaults) = if !enc_fields.is_empty() {
let enc_names_str = enc_fields
.iter()
.map(|enc| enc.field_name.clone())
.collect::<Vec<String>>()
.join(", ");
(
Cow::from(format!(
"|({}, {})| (Some({}), {})",
field.name, enc_names_str, field.name, enc_names_str
)),
Cow::from(format!(
"(None, {})",
enc_fields
.iter()
.map(|enc| enc.default_expr.to_owned())
.collect::<Vec<String>>()
.join(", ")
)),
)
} else {
(Cow::from("Some"), Cow::from("None"))
};
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", true),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
);
// this block needs to evaluate to a Result even though it has no value
err_deser.content.line("Ok(())");
err_deser
.annotate(&field.name, "", "?;")
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("", "", false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
} else if cli.annotate_fields {
)
.annotate(&field.name, "", &format!(".map({some_map})"))
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line(format!("Ok({defaults})"));
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new("Some(", ")", false),
DeserializeConfig::new(&field.name)
.in_embedded(in_embedded)
.optional_field(true),
cli,
)
.wrap_in_block(type_check_block)
.add_to_code(&mut deser_code);
type_check_else.line("None");
}
type_check_else.after(after);
deser_code.content.push_block(type_check_else);
} else {
// mandatory fields
if cli.annotate_fields {
gen_scope
.generate_deserialize(
types,
Expand All @@ -4429,22 +4537,22 @@ fn generate_array_struct_deserialization(
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.annotate(&field.name, &format!("let {} = ", field.name), "?;")
.annotate(&field.name, before.as_ref(), after.as_ref())
.add_to_code(&mut deser_code);
} else {
gen_scope
.generate_deserialize(
types,
(&field.rust_type).into(),
DeserializeBeforeAfter::new(&format!("let {} = ", field.name), ";", false),
DeserializeBeforeAfter::new(before.as_ref(), after.as_ref(), false),
DeserializeConfig::new(&field.name).in_embedded(in_embedded),
cli,
)
.add_to_code(&mut deser_code);
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if !field.rust_type.is_fixed_value() {
deser_ctor_fields.push((field.name.clone(), field.name.clone()));
}
}
if cli.preserve_encodings {
Expand All @@ -4459,18 +4567,15 @@ fn generate_array_struct_deserialization(
encoding_vars_output.push(("tag_encoding".to_owned(), "Some(tag_encoding)".to_owned()));
}
for field in record.fields.iter() {
// we don't support deserialization for optional fields so don't even bother
if !field.optional {
for field_enc in encoding_fields(
types,
&field.name,
&field.rust_type.clone().resolve_aliases(),
true,
cli,
) {
encoding_vars_output
.push((field_enc.field_name.clone(), field_enc.field_name.clone()));
}
for field_enc in encoding_fields(
types,
&field.name,
&field.rust_type.clone().resolve_aliases(),
true,
cli,
) {
encoding_vars_output
.push((field_enc.field_name.clone(), field_enc.field_name.clone()));
}
}
}
Expand Down Expand Up @@ -4607,7 +4712,7 @@ fn codegen_struct(
gen_scope.dont_generate_deserialize(
name,
format!(
"field {}: {} couldn't generate serialize",
"field {}: {} couldn't generate deserialize",
field.name,
field.rust_type.for_rust_member(types, false, cli)
),
Expand Down
4 changes: 4 additions & 0 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,10 @@ impl RustRecord {
let mut conditional_field_expr = String::new();
for field in &self.fields {
if field.optional {
if !cli.preserve_encodings && field.rust_type.is_fixed_value() {
// we don't create fields for fixed values when preserve-encodings=false
continue;
}
if !conditional_field_expr.is_empty() {
conditional_field_expr.push_str(" + ");
}
Expand Down
13 changes: 12 additions & 1 deletion tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,15 @@ overlapping0 = [0]
overlapping1 = [0, uint]
overlapping2 = [0, uint, text]

overlapping = overlapping0 / overlapping1 / overlapping2
overlapping = overlapping0 / overlapping1 / overlapping2

array_opt_fields = [
? x: 1.010101
? a: uint,
? b: text,
c: nint,
? d: text,
y: 3.14159265
? e: non_overlapping_type_choice_some
? z: 2.71828
]
Loading
Loading