Enable an #[safety_constraint(...)] attribute helper for the `Arbitrary` and `Invariant` macros (#3283)

This PR enables an `#[safety_constraint(...)]` attribute helper for the
`#[derive(Arbitrary)]` and `#[derive(Invariant)]` macros.

For the `Invariant` derive macro, this allows users to derive more
sophisticated invariants for their data types by annotating individual
named fields with the `#[safety_constraint(<cond>)]` attribute, where
`<cond>` represents the predicate to be evaluated for the corresponding
field. In addition, the implementation always checks `#field.is_safe()`
for each field.

For example, let's say we are working with the `Point` type from #3250:
```rs
#[derive(kani::Invariant)]
struct Point<X, Y> {
    x: X,
    y: Y,
}
```

and we need to extend it to only allow non-negative values for both `x` and
`y`.
With the `#[safety_constraint(...)]` attribute, we can achieve this
without explicitly writing the `impl Invariant for ...` as follows:

```rs
#[derive(kani::Invariant)]
struct PositivePoint {
    #[safety_constraint(*x >= 0)]
    x: i32,
    #[safety_constraint(*y >= 0)]
    y: i32,
}
```
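
As a rough illustration, the derived `is_safe()` for `PositivePoint` should behave like the hand-written version below. This is only a sketch that builds without Kani: the local `Invariant` trait and its `i32` impl are stand-ins (assumptions) for `kani::Invariant`, and the method body follows the expansion documented in this commit (`true && <cond> && field.is_safe()` per field).

```rs
// Stand-in for `kani::Invariant` so the sketch builds without Kani (assumption).
trait Invariant {
    fn is_safe(&self) -> bool;
}

// Assumed: every `i32` value is valid, so its invariant is trivially true.
impl Invariant for i32 {
    fn is_safe(&self) -> bool {
        true
    }
}

struct PositivePoint {
    x: i32,
    y: i32,
}

// Hand-written equivalent of what `#[derive(kani::Invariant)]` is expected to generate.
impl Invariant for PositivePoint {
    fn is_safe(&self) -> bool {
        let obj = self;
        // Field references, so the conditions can be written as `*x`, `*y`.
        let x = &obj.x;
        let y = &obj.y;
        true && *x >= 0 && x.is_safe() && *y >= 0 && y.is_safe()
    }
}
```

Under these assumptions, `PositivePoint { x: 1, y: 2 }.is_safe()` holds, while a point with a negative coordinate is rejected.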

For the `Arbitrary` derive macro, this allows users to derive more
sophisticated `kani::any()` generators that respect the specified
invariants. In other words, `kani::any()` will assume any invariants
specified through the `#[safety_constraint(...)]` attribute helper.
Going back to the `PositivePoint` example, we'd expect this harness to
be successful:

```rs
#[kani::proof]
fn check_invariant_helper_ok() {
    let pos_point: PositivePoint = kani::any();
    assert!(pos_point.x >= 0);
    assert!(pos_point.y >= 0);
}
```
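
To make the role of `kani::assume` more concrete, here is a small model of the derived `any()` that runs on ordinary values. It is a sketch, not the generated code: `any_positive_point` is a hypothetical helper, its two arguments stand in for the unconstrained `kani::any()` values, and returning `None` stands in for an execution path that `kani::assume` discards.

```rs
#[derive(Debug, PartialEq)]
struct PositivePoint {
    x: i32,
    y: i32,
}

/// Concrete model of the derived `any()` (hypothetical helper).
/// `raw_x`/`raw_y` play the role of unconstrained `kani::any()` values.
fn any_positive_point(raw_x: i32, raw_y: i32) -> Option<PositivePoint> {
    let obj = PositivePoint { x: raw_x, y: raw_y };
    // Field references, mirroring the `let x = &obj.x;` lines emitted by `field_refs`.
    let x = &obj.x;
    let y = &obj.y;
    // Conjunction of the `#[safety_constraint(...)]` conditions; a `false`
    // result models `kani::assume` ruling out this value.
    if *x >= 0 && *y >= 0 { Some(obj) } else { None }
}
```

Every value that survives the check satisfies both constraints, which is why the assertions in `check_invariant_helper_ok` are expected to hold.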

The PR includes tests to ensure it's working as expected, in addition to
UI tests checking for cases where the arguments provided to the macro
are incorrect. Happy to add any other cases that you feel are missing.

Related #3095
adpaco-aws committed Jul 22, 2024
1 parent 5d6bf69 commit 7ad4d1c
Showing 24 changed files with 668 additions and 20 deletions.
202 changes: 186 additions & 16 deletions library/kani_macros/src/derive.rs
@@ -28,11 +28,30 @@ pub fn expand_derive_arbitrary(item: proc_macro::TokenStream) -> proc_macro::Tok
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let body = fn_any_body(&item_name, &derive_item.data);
let expanded = quote! {
// The generated implementation.
impl #impl_generics kani::Arbitrary for #item_name #ty_generics #where_clause {
fn any() -> Self {
#body

// Get the safety constraints (if any) to produce type-safe values
let safety_conds_opt = safety_conds(&item_name, &derive_item.data);

let expanded = if let Some(safety_cond) = safety_conds_opt {
let field_refs = field_refs(&item_name, &derive_item.data);
quote! {
// The generated implementation.
impl #impl_generics kani::Arbitrary for #item_name #ty_generics #where_clause {
fn any() -> Self {
let obj = #body;
#field_refs
kani::assume(#safety_cond);
obj
}
}
}
} else {
quote! {
// The generated implementation.
impl #impl_generics kani::Arbitrary for #item_name #ty_generics #where_clause {
fn any() -> Self {
#body
}
}
}
};
@@ -75,6 +94,103 @@ fn fn_any_body(ident: &Ident, data: &Data) -> TokenStream {
}
}

/// Parse the condition expressions in `#[safety_constraint(<cond>)]` attached to struct
/// fields and, if at least one was found, generate a conjunction to be assumed.
///
/// For example, if we're deriving implementations for the struct
/// ```
/// #[derive(Arbitrary)]
/// #[derive(Invariant)]
/// struct PositivePoint {
/// #[safety_constraint(*x >= 0)]
/// x: i32,
/// #[safety_constraint(*y >= 0)]
/// y: i32,
/// }
/// ```
/// this function will generate the `TokenStream`
/// ```
/// *x >= 0 && *y >= 0
/// ```
/// which can be passed to `kani::assume` to constrain the values generated
/// through the `Arbitrary` impl so that they are type-safe by construction.
fn safety_conds(ident: &Ident, data: &Data) -> Option<TokenStream> {
match data {
Data::Struct(struct_data) => safety_conds_inner(ident, &struct_data.fields),
Data::Enum(_) => None,
Data::Union(_) => None,
}
}

/// Generates an expression resulting from the conjunction of conditions
/// specified as safety constraints for each field. See `safety_conds` for more details.
fn safety_conds_inner(ident: &Ident, fields: &Fields) -> Option<TokenStream> {
match fields {
Fields::Named(ref fields) => {
let conds: Vec<TokenStream> =
fields.named.iter().filter_map(|field| parse_safety_expr(ident, field)).collect();
if !conds.is_empty() { Some(quote! { #(#conds)&&* }) } else { None }
}
Fields::Unnamed(_) => None,
Fields::Unit => None,
}
}

/// Generates the sequence of expressions to initialize the variables used as
/// references to the struct fields.
///
/// For example, if we're deriving implementations for the struct
/// ```
/// #[derive(Arbitrary)]
/// #[derive(Invariant)]
/// struct PositivePoint {
/// #[safety_constraint(*x >= 0)]
/// x: i32,
/// #[safety_constraint(*y >= 0)]
/// y: i32,
/// }
/// ```
/// this function will generate the `TokenStream`
/// ```
/// let x = &obj.x;
/// let y = &obj.y;
/// ```
/// which allows us to refer to the struct fields without using `self`.
/// Note that the actual stream is generated in the `field_refs_inner` function.
fn field_refs(ident: &Ident, data: &Data) -> TokenStream {
match data {
Data::Struct(struct_data) => field_refs_inner(ident, &struct_data.fields),
Data::Enum(_) => unreachable!(),
Data::Union(_) => unreachable!(),
}
}

/// Generates the sequence of expressions to initialize the variables used as
/// references to the struct fields. See `field_refs` for more details.
fn field_refs_inner(_ident: &Ident, fields: &Fields) -> TokenStream {
match fields {
Fields::Named(ref fields) => {
let field_refs: Vec<TokenStream> = fields
.named
.iter()
.map(|field| {
let name = &field.ident;
quote_spanned! {field.span()=>
let #name = &obj.#name;
}
})
.collect();
if !field_refs.is_empty() {
quote! { #( #field_refs )* }
} else {
quote! {}
}
}
Fields::Unnamed(_) => quote! {},
Fields::Unit => quote! {},
}
}

/// Generate an item initialization where an item can be a struct or a variant.
/// For named fields, this will generate: `Item { field1: kani::any(), field2: kani::any(), .. }`
/// For unnamed fields, this will generate: `Item (kani::any(), kani::any(), ..)`
@@ -115,6 +231,42 @@ fn init_symbolic_item(ident: &Ident, fields: &Fields) -> TokenStream {
}
}

/// Extract, parse and return the expression `cond` (i.e., `Some(cond)`) in the
/// `#[safety_constraint(<cond>)]` attribute helper associated with a given field.
/// Return `None` if the attribute isn't specified.
fn parse_safety_expr(ident: &Ident, field: &syn::Field) -> Option<TokenStream> {
let name = &field.ident;
let mut safety_helper_attr = None;

// Keep the helper attribute if we find it
for attr in &field.attrs {
if attr.path().is_ident("safety_constraint") {
safety_helper_attr = Some(attr);
}
}

// Parse the arguments in the `#[safety_constraint(...)]` attribute
if let Some(attr) = safety_helper_attr {
let expr_args: Result<syn::Expr, syn::Error> = attr.parse_args();

// Check if there was an error parsing the arguments
if let Err(err) = expr_args {
abort!(Span::call_site(), "Cannot derive impl for `{}`", ident;
note = attr.span() =>
"safety constraint in field `{}` could not be parsed: {}", name.as_ref().unwrap().to_string(), err
)
}

// Return the expression associated to the safety constraint
let safety_expr = expr_args.unwrap();
Some(quote_spanned! {field.span()=>
#safety_expr
})
} else {
None
}
}

/// Generate the body of the function `any()` for enums. The cases are:
/// 1. For zero-variants enumerations, this will encode a `panic!()` statement.
/// 2. For one or more variants, the code will be something like:
@@ -176,10 +328,14 @@ pub fn expand_derive_invariant(item: proc_macro::TokenStream) -> proc_macro::Tok
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let body = is_safe_body(&item_name, &derive_item.data);
let field_refs = field_refs(&item_name, &derive_item.data);

let expanded = quote! {
// The generated implementation.
impl #impl_generics kani::Invariant for #item_name #ty_generics #where_clause {
fn is_safe(&self) -> bool {
let obj = self;
#field_refs
#body
}
}
@@ -199,7 +355,7 @@ fn add_trait_bound_invariant(mut generics: Generics) -> Generics {

fn is_safe_body(ident: &Ident, data: &Data) -> TokenStream {
match data {
Data::Struct(struct_data) => struct_safe_conjunction(ident, &struct_data.fields),
Data::Struct(struct_data) => struct_invariant_conjunction(ident, &struct_data.fields),
Data::Enum(_) => {
abort!(Span::call_site(), "Cannot derive `Invariant` for `{}` enum", ident;
note = ident.span() =>
@@ -215,21 +371,35 @@ fn is_safe_body(ident: &Ident, data: &Data) -> TokenStream {
}
}

/// Generates an expression that is the conjunction of `is_safe` calls for each field in the struct.
fn struct_safe_conjunction(_ident: &Ident, fields: &Fields) -> TokenStream {
/// Generates an expression that is the conjunction of safety constraints for each field in the struct.
fn struct_invariant_conjunction(ident: &Ident, fields: &Fields) -> TokenStream {
match fields {
// Expands to the expression
// `true && <safety_cond1> && <safety_cond2> && ..`
// where `safety_condN` is
// - `<cond> && fieldN.is_safe()` if a condition `<cond>` was
// specified through the `#[safety_constraint(<cond>)]` helper attribute, or
// - `fieldN.is_safe()` otherwise
// and `fieldN` is the local reference to `self.fieldN` created by `field_refs`.
//
// Therefore, if `#[safety_constraint(<cond>)]` isn't specified for any field, this expands to
// `true && field1.is_safe() && field2.is_safe() && ..`
Fields::Named(ref fields) => {
let safe_calls = fields.named.iter().map(|field| {
let name = &field.ident;
quote_spanned! {field.span()=>
self.#name.is_safe()
}
});
let safety_conds: Vec<TokenStream> = fields
.named
.iter()
.map(|field| {
let name = &field.ident;
let default_expr = quote_spanned! {field.span()=>
#name.is_safe()
};
parse_safety_expr(ident, field)
.map(|expr| quote! { #expr && #default_expr})
.unwrap_or(default_expr)
})
.collect();
// An initial value is required for empty structs
safe_calls.fold(quote! { true }, |acc, call| {
quote! { #acc && #call }
safety_conds.iter().fold(quote! { true }, |acc, cond| {
quote! { #acc && #cond }
})
}
Fields::Unnamed(ref fields) => {
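
The shape of the conjunction built for named fields can be modelled with plain strings instead of token streams. The sketch below is only an illustration under that assumption: `FieldSpec` and `invariant_conjunction` are hypothetical names, and the real macro operates on `TokenStream`s via `quote!` rather than on text.

```rs
/// Hypothetical, string-based model of a struct field: its name plus the optional
/// condition taken from `#[safety_constraint(<cond>)]`.
struct FieldSpec<'a> {
    name: &'a str,
    cond: Option<&'a str>,
}

/// Builds the `is_safe()` body as text: start from `true` (needed for empty
/// structs), then append one conjunct per field.
fn invariant_conjunction(fields: &[FieldSpec]) -> String {
    fields.iter().fold(String::from("true"), |acc, field| {
        // Every field always contributes its own `is_safe()` call.
        let default_expr = format!("{}.is_safe()", field.name);
        // A user-provided condition, if any, is placed before that call.
        let conjunct = match field.cond {
            Some(cond) => format!("{} && {}", cond, default_expr),
            None => default_expr,
        };
        format!("{} && {}", acc, conjunct)
    })
}
```

For a struct with `x` annotated with `*x >= 0` and an unannotated `y`, this model yields `true && *x >= 0 && x.is_safe() && y.is_safe()`.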
112 changes: 108 additions & 4 deletions library/kani_macros/src/lib.rs
@@ -100,16 +100,120 @@ pub fn unstable_feature(attr: TokenStream, item: TokenStream) -> TokenStream {
attr_impl::unstable(attr, item)
}

/// Allow users to auto generate Arbitrary implementations by using `#[derive(Arbitrary)]` macro.
/// Allow users to auto generate `Arbitrary` implementations by using
/// `#[derive(Arbitrary)]` macro.
///
/// When using `#[derive(Arbitrary)]` on a struct, the `#[safety_constraint(<cond>)]`
/// attribute can be added to its fields to indicate a type safety invariant
/// condition `<cond>`. Since `kani::any()` is always expected to produce
/// type-safe values, **adding `#[safety_constraint(...)]` to any fields will further
/// constrain the objects generated with `kani::any()`**.
///
/// For example, the `check_positive` harness in this code is expected to
/// pass:
///
/// ```rs
/// #[derive(kani::Arbitrary)]
/// struct AlwaysPositive {
/// #[safety_constraint(*inner >= 0)]
/// inner: i32,
/// }
///
/// #[kani::proof]
/// fn check_positive() {
/// let val: AlwaysPositive = kani::any();
/// assert!(val.inner >= 0);
/// }
/// ```
///
/// Therefore, using the `#[safety_constraint(...)]` attribute can lead to vacuous
/// results when the values are over-constrained. For example, in this code
/// the `check_positive` harness will pass too:
///
/// ```rs
/// #[derive(kani::Arbitrary)]
/// struct AlwaysPositive {
/// #[safety_constraint(*inner >= 0 && *inner < i32::MIN)]
/// inner: i32,
/// }
///
/// #[kani::proof]
/// fn check_positive() {
/// let val: AlwaysPositive = kani::any();
/// assert!(val.inner >= 0);
/// }
/// ```
///
/// Unfortunately, we made a mistake when specifying the condition because
/// `*inner >= 0 && *inner < i32::MIN` is equivalent to `false`. This results
/// in the relevant assertion being unreachable:
///
/// ```
/// Check 1: check_positive.assertion.1
/// - Status: UNREACHABLE
/// - Description: "assertion failed: val.inner >= 0"
/// - Location: src/main.rs:22:5 in function check_positive
/// ```
///
/// As usual, we recommend that users defend against these behaviors by using
/// `kani::cover!(...)` checks and watching out for unreachable assertions in
/// their project's code.
#[proc_macro_error]
#[proc_macro_derive(Arbitrary)]
#[proc_macro_derive(Arbitrary, attributes(safety_constraint))]
pub fn derive_arbitrary(item: TokenStream) -> TokenStream {
derive::expand_derive_arbitrary(item)
}
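
The over-constrained condition from the `derive_arbitrary` documentation can also be probed with ordinary Rust. The sketch below is an assumption-laden stand-in for a real `kani::cover!` check: it only samples a few boundary values, and the function names are hypothetical.

```rs
/// The over-constrained condition from the documentation, as a plain predicate.
fn over_constrained(inner: &i32) -> bool {
    *inner >= 0 && *inner < i32::MIN
}

/// The condition that was presumably intended.
fn intended(inner: &i32) -> bool {
    *inner >= 0
}

/// Returns true if at least one sampled value satisfies `pred`. This is a much
/// weaker, concrete analogue of a cover check, which would reason about all values.
fn satisfiable_on_samples(pred: fn(&i32) -> bool) -> bool {
    [i32::MIN, -1, 0, 1, i32::MAX].iter().any(pred)
}
```

No `i32` is smaller than `i32::MIN`, so `over_constrained` rejects every sample, while `intended` accepts the non-negative ones. A condition that no value satisfies is what makes the assertion unreachable in the report shown in the documentation.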

/// Allow users to auto generate Invariant implementations by using `#[derive(Invariant)]` macro.
/// Allow users to auto generate `Invariant` implementations by using
/// `#[derive(Invariant)]` macro.
///
/// When using `#[derive(Invariant)]` on a struct, the `#[safety_constraint(<cond>)]`
/// attribute can be added to its fields to indicate a type safety invariant
/// condition `<cond>`. This will ensure that the condition gets additionally checked when
/// using the `is_safe()` method generated by the `#[derive(Invariant)]` macro.
///
/// For example, the `check_positive` harness in this code is expected to
/// fail:
///
/// ```rs
/// #[derive(kani::Invariant)]
/// struct AlwaysPositive {
/// #[safety_constraint(*inner >= 0)]
/// inner: i32,
/// }
///
/// #[kani::proof]
/// fn check_positive() {
/// let val = AlwaysPositive { inner: -1 };
/// assert!(val.is_safe());
/// }
/// ```
///
/// This is not too surprising since the type safety invariant that we indicated
/// is not being taken into account when we create the `AlwaysPositive` object.
///
/// As mentioned, the `is_safe()` methods generated by the
/// `#[derive(Invariant)]` macro check the corresponding `is_safe()` method for
/// each field in addition to any type safety invariants specified through the
/// `#[safety_constraint(...)]` attribute.
///
/// For example, for the `AlwaysPositive` struct from above, we will generate
/// the following implementation:
///
/// ```rs
/// impl kani::Invariant for AlwaysPositive {
/// fn is_safe(&self) -> bool {
/// let obj = self;
/// let inner = &obj.inner;
/// true && *inner >= 0 && inner.is_safe()
/// }
/// }
/// ```
///
/// Note: the assignments to `obj` and `inner` are made so that we can treat the
/// fields as if they were references.
#[proc_macro_error]
#[proc_macro_derive(Invariant)]
#[proc_macro_derive(Invariant, attributes(safety_constraint))]
pub fn derive_invariant(item: TokenStream) -> TokenStream {
derive::expand_derive_invariant(item)
}
17 changes: 17 additions & 0 deletions tests/expected/derive-arbitrary/safety_constraint_helper/expected
@@ -0,0 +1,17 @@
Check 1: check_invariant_helper_ok.assertion.1\
- Status: SUCCESS\
- Description: "assertion failed: pos_point.x >= 0"

Check 2: check_invariant_helper_ok.assertion.2\
- Status: SUCCESS\
- Description: "assertion failed: pos_point.y >= 0"

Check 1: check_invariant_helper_fail.assertion.1\
- Status: FAILURE\
- Description: "assertion failed: pos_point.x >= 0"

Check 2: check_invariant_helper_fail.assertion.2\
- Status: FAILURE\
- Description: "assertion failed: pos_point.y >= 0"

Complete - 2 successfully verified harnesses, 0 failures, 2 total.