Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve MRSolver interface #1675

Merged
merged 7 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,11 @@ ppTermInMonCtx :: MonadifyCtx -> Term -> String
ppTermInMonCtx ctx t =
scPrettyTermInCtx defaultPPOpts (map (\(x,_,_) -> x) ctx) t

-- | A memoization table for monadifying terms
type MonadifyMemoTable = IntMap MonTerm
-- | A memoization table for monadifying terms: a map from 'TermIndex'es to
-- 'MonTerm's and, possibly, corresponding 'ArgMonTerm's. The latter are simply
-- the result of calling 'argifyMonTerm' on the former, but are only added when
-- needed (i.e. when 'memoArgMonTerm' is called, e.g. in 'monadifyArg').
type MonadifyMemoTable = IntMap (MonTerm, Maybe ArgMonTerm)

-- | The empty memoization table
emptyMemoTable :: MonadifyMemoTable
Expand Down Expand Up @@ -752,15 +755,34 @@ runCompleteMonadifyM sc env top_ret_tp m =
runMonadifyM env [] (toArgType $ monadifyType [] top_ret_tp) m

-- | Memoize a computation of the monadified term associated with a 'TermIndex'
memoizingM :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm
memoizingM i m =
memoMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm
memoMonTerm i m =
(IntMap.lookup i <$> get) >>= \case
Just ret ->
return ret
Just (mtm, _) ->
return mtm
Nothing ->
do ret <- m
modify (IntMap.insert i ret)
return ret
do mtm <- m
modify (IntMap.insert i (mtm, Nothing))
return mtm

-- | Memoize a computation of the monadified term of argument type associated
-- with a 'TermIndex', using a memoized 'ArgTerm' directly if it exists or
-- applying 'argifyMonTerm' to a memoized 'MonTerm' (and memoizing the result)
-- if it exists
memoArgMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM ArgMonTerm
memoArgMonTerm i m =
(IntMap.lookup i <$> get) >>= \case
Just (_, Just argmtm) ->
return argmtm
Just (mtm, Nothing) ->
do argmtm <- argifyMonTerm mtm
modify (IntMap.insert i (mtm, Just argmtm))
return argmtm
Nothing ->
do mtm <- m
argmtm <- argifyMonTerm mtm
modify (IntMap.insert i (mtm, Just argmtm))
return argmtm

-- | Turn a 'MonTerm' of type @CompMT(tp)@ to a term of argument type @MT(tp)@
-- by inserting a monadic bind if the 'MonTerm' is computational
Expand Down Expand Up @@ -799,7 +821,15 @@ monadifyTypeM tp =

-- | Monadify a term to a monadified term of argument type
monadifyArg :: Maybe MonType -> Term -> MonadifyM ArgMonTerm
monadifyArg mtp t = monadifyTerm mtp t >>= argifyMonTerm
{-
monadifyArg _ t
| trace ("Monadifying term of argument type: " ++ showTerm t) False
= undefined
-}
monadifyArg mtp t@(STApp { stAppIndex = ix }) =
memoArgMonTerm ix $ monadifyTerm' mtp t
monadifyArg mtp t =
monadifyTerm' mtp t >>= argifyMonTerm

-- | Monadify a term to argument type and convert back to a term
monadifyArgTerm :: Maybe MonType -> Term -> MonadifyM OpenTerm
Expand All @@ -813,7 +843,7 @@ monadifyTerm _ t
= undefined
-}
monadifyTerm mtp t@(STApp { stAppIndex = ix }) =
memoizingM ix $ monadifyTerm' mtp t
memoMonTerm ix $ monadifyTerm' mtp t
monadifyTerm mtp t =
monadifyTerm' mtp t

Expand Down
22 changes: 11 additions & 11 deletions examples/mr_solver/mr_solver_unit_tests.saw
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ let run_test name test expected =
do { if expected then print (str_concat "Test: " name) else
print (str_concat (str_concat "Test: " name) " (expecting failure)");
actual <- test;
if eq_bool actual expected then print "Success\n" else
if eq_bool actual expected then print "Test passed\n" else
do { print "Test failed\n"; exit 1; }; };

// The constant 0 function const0 x = 0
Expand All @@ -21,19 +21,19 @@ const0 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 0)";
const1 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 1)";

// const0 <= const0
run_test "mr_solver const0 const0" (mr_solver const0 const0) true;
run_test "const0 |= const0" (mr_solver_query const0 const0) true;

// The function test_fun0 from the prelude = const0
test_fun0 <- parse_core "test_fun0";
run_test "mr_solver const0 test_fun0" (mr_solver const0 test_fun0) true;
run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true;

// not const0 <= const1
run_test "mr_solver const0 const1" (mr_solver const0 const1) false;
run_test "const0 |= const1" (mr_solver_query const0 const1) false;

// The function test_fun1 from the prelude = const1
test_fun1 <- parse_core "test_fun1";
run_test "mr_solver const1 test_fun1" (mr_solver const1 test_fun1) true;
run_test "mr_solver const0 test_fun1" (mr_solver const0 test_fun1) false;
run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true;
run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false;

// ifxEq0 x = If x == 0 then x else 0; should be equal to 0
ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \
Expand All @@ -42,21 +42,21 @@ ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \
\ (returnM (Vec 64 Bool) (bvNat 64 0))";

// ifxEq0 <= const0
run_test "mr_solver ifxEq0 const0" (mr_solver ifxEq0 const0) true;
run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true;

// not ifxEq0 <= const1
run_test "mr_solver ifxEq0 const1" (mr_solver ifxEq0 const1) false;
run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false;

// noErrors1 x = exists x. returnM x
noErrors1 <- parse_core "\\ (x:Vec 64 Bool) -> \
\ existsM (Vec 64 Bool) (Vec 64 Bool) \
\ (\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool) x)";

// const0 <= noErrors
run_test "mr_solver noErrors1 noErrors1" (mr_solver noErrors1 noErrors1) true;
run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true;

// const1 <= noErrors
run_test "mr_solver const1 noErrors1" (mr_solver const1 noErrors1) true;
run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true;

// noErrorsRec1 x = orM (existsM x. returnM x) (noErrorsRec1 x)
// Intuitively, this specifies functions that either return a value or loop
Expand All @@ -74,4 +74,4 @@ loop1 <- parse_core
\ (\\ (f: Vec 64 Bool -> CompM (Vec 64 Bool)) (x:Vec 64 Bool) -> f x)";

// loop1 <= noErrorsRec1
run_test "mr_solver loop1 noErrorsRec1" (mr_solver loop1 noErrorsRec1) true;
run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true;
24 changes: 4 additions & 20 deletions heapster-saw/examples/arrays_mr_solver.saw
Original file line number Diff line number Diff line change
@@ -1,31 +1,15 @@
include "arrays.saw";

let eq_bool b1 b2 =
if b1 then
if b2 then true else false
else
if b2 then false else true;

let fail = do { print "Test failed"; exit 1; };
let run_test name test expected =
do { if expected then print (str_concat "Test: " name) else
print (str_concat (str_concat "Test: " name) " (expecting failure)");
actual <- test;
if eq_bool actual expected then print "Success\n" else
do { print "Test failed\n"; exit 1; }; };

// Test that contains0 |= contains0
contains0 <- parse_core_mod "arrays" "contains0";
// run_test "contains0 |= contains0" (mr_solver contains0 contains0) true;
mr_solver_test contains0 contains0;

noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0";
run_test "contains0 |= noErrorsContains0"
(mr_solver_debug 0 contains0 noErrorsContains0) true;
mr_solver_prove contains0 noErrorsContains0;

include "specPrims.saw";
import "arrays.cry";

zero_array <- parse_core_mod "arrays" "zero_array";
run_test "zero_array |= zero_array_spec"
// (mr_solver_debug 0 zero_array {{ zero_array_loop_spec }}) true;
(mr_solver_debug 0 zero_array {{ zero_array_spec }}) true;
// mr_solver_prove zero_array {{ zero_array_loop_spec }};
mr_solver_prove zero_array {{ zero_array_spec }};
18 changes: 1 addition & 17 deletions heapster-saw/examples/exp_explosion_mr_solver.saw
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
include "exp_explosion.saw";

let eq_bool b1 b2 =
if b1 then
if b2 then true else false
else
if b2 then false else true;

let fail = do { print "Test failed"; exit 1; };
let run_test name test expected =
do { if expected then print (str_concat "Test: " name) else
print (str_concat (str_concat "Test: " name) " (expecting failure)");
actual <- test;
if eq_bool actual expected then print "Success\n" else
do { print "Test failed\n"; exit 1; }; };



import "exp_explosion.cry";
monadify_term {{ op }};

exp_explosion <- parse_core_mod "exp_explosion" "exp_explosion";
run_test "exp_explosion |= exp_explosion_spec" (mr_solver exp_explosion {{ exp_explosion_spec }}) true;
mr_solver_prove exp_explosion {{ exp_explosion_spec }};
35 changes: 6 additions & 29 deletions heapster-saw/examples/linked_list_mr_solver.saw
Original file line number Diff line number Diff line change
@@ -1,24 +1,5 @@
include "linked_list.saw";

/***
*** Testing infrastructure
***/

let eq_bool b1 b2 =
if b1 then
if b2 then true else false
else
if b2 then false else true;

let fail = do { print "Test failed"; exit 1; };
let run_test name test expected =
do { if expected then print (str_concat "Test: " name) else
print (str_concat (str_concat "Test: " name) " (expecting failure)");
actual <- test;
if eq_bool actual expected then print "Success\n" else
do { print "Test failed\n"; exit 1; }; };


/***
*** Setup Cryptol environment
***/
Expand All @@ -45,15 +26,13 @@ heapster_typecheck_fun env "is_head"
"(). arg0:int64<>, arg1:List<int64<>,always,R> -o \
\ arg0:true, arg1:true, ret:int64<>";

/*
is_head <- parse_core_mod "linked_list" "is_head";
run_test "is_head |= is_head" (mr_solver is_head is_head) true;
*/
mr_solver_test is_head is_head;

is_elem <- parse_core_mod "linked_list" "is_elem";
// run_test "is_elem |= is_elem" (mr_solver_debug 0 is_elem is_elem) true;

/*
mr_solver_test is_elem is_elem;

is_elem_noErrorsSpec <- parse_core
"\\ (x:Vec 64 Bool) (y:List (Vec 64 Bool)) -> \
\ fixM (Vec 64 Bool * List (Vec 64 Bool)) \
Expand All @@ -63,10 +42,9 @@ is_elem_noErrorsSpec <- parse_core
\ orM (Vec 64 Bool) \
\ (existsM (Vec 64 Bool) (Vec 64 Bool) (returnM (Vec 64 Bool))) \
\ (rec x)) (x, y)";
run_test "is_elem |= noErrorsSpec" (mr_solver is_elem is_elem_noErrorsSpec) true;
*/
mr_solver_test is_elem is_elem_noErrorsSpec;

run_test "is_elem |= is_elem_spec" (mr_solver is_elem {{ is_elem_spec }}) true;
mr_solver_prove is_elem {{ is_elem_spec }};


monadify_term {{ Right }};
Expand All @@ -75,5 +53,4 @@ monadify_term {{ nil }};
monadify_term {{ cons }};

sorted_insert_no_malloc <- parse_core_mod "linked_list" "sorted_insert_no_malloc";
run_test "sorted_insert_no_malloc |= sorted_insert_spec"
(mr_solver sorted_insert_no_malloc {{ sorted_insert_spec }}) true;
mr_solver_prove sorted_insert_no_malloc {{ sorted_insert_spec }};
12 changes: 6 additions & 6 deletions heapster-saw/examples/sha512_mr_solver.saw
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ processBlock <- parse_core_mod "SHA512" "processBlock";
processBlocks <- parse_core_mod "SHA512" "processBlocks";

// Test that every function refines itself
// run_test "processBlocks |= processBlocks" (mr_solver processBlocks processBlocks) true;
// run_test "processBlock |= processBlock" (mr_solver processBlock processBlock) true;
// run_test "round_16_80 |= round_16_80" (mr_solver round_16_80 round_16_80) true;
// run_test "round_00_15 |= round_00_15" (mr_solver round_00_15 round_00_15) true;
// mr_solver_test processBlocks processBlocks;
// mr_solver_test processBlock processBlock;
// mr_solver_test round_16_80 round_16_80;
// mr_solver_test round_00_15 round_00_15;

import "sha512.cry";
// FIXME: Why aren't we monadifying these automatically when they're used?
Expand All @@ -105,5 +105,5 @@ monadify_term {{ Maj }};
// "round_16_80 |= round_16_80_spec"?
monadify_term {{ round_00_15_spec }};

run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true;
run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true;
mr_solver_prove round_00_15 {{ round_00_15_spec }};
mr_solver_prove round_16_80 {{ round_16_80_spec }};
Loading