diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index 53656dd6f7..76de3b0318 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -681,8 +681,11 @@ ppTermInMonCtx :: MonadifyCtx -> Term -> String ppTermInMonCtx ctx t = scPrettyTermInCtx defaultPPOpts (map (\(x,_,_) -> x) ctx) t --- | A memoization table for monadifying terms -type MonadifyMemoTable = IntMap MonTerm +-- | A memoization table for monadifying terms: a map from 'TermIndex'es to +-- 'MonTerm's and, possibly, corresponding 'ArgMonTerm's. The latter are simply +-- the result of calling 'argifyMonTerm' on the former, but are only added when +-- needed (i.e. when 'memoArgMonTerm' is called, e.g. in 'monadifyArg'). +type MonadifyMemoTable = IntMap (MonTerm, Maybe ArgMonTerm) -- | The empty memoization table emptyMemoTable :: MonadifyMemoTable @@ -752,15 +755,34 @@ runCompleteMonadifyM sc env top_ret_tp m = runMonadifyM env [] (toArgType $ monadifyType [] top_ret_tp) m -- | Memoize a computation of the monadified term associated with a 'TermIndex' -memoizingM :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm -memoizingM i m = +memoMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm +memoMonTerm i m = (IntMap.lookup i <$> get) >>= \case - Just ret -> - return ret + Just (mtm, _) -> + return mtm Nothing -> - do ret <- m - modify (IntMap.insert i ret) - return ret + do mtm <- m + modify (IntMap.insert i (mtm, Nothing)) + return mtm + +-- | Memoize a computation of the monadified term of argument type associated +-- with a 'TermIndex', using a memoized 'ArgTerm' directly if it exists or +-- applying 'argifyMonTerm' to a memoized 'MonTerm' (and memoizing the result) +-- if it exists +memoArgMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM ArgMonTerm +memoArgMonTerm i m = + (IntMap.lookup i <$> get) >>= \case + Just (_, Just argmtm) -> + return argmtm + Just (mtm, Nothing) -> + do argmtm <- argifyMonTerm mtm + modify (IntMap.insert i (mtm, Just argmtm)) + return argmtm + Nothing -> + do mtm <- m + argmtm <- argifyMonTerm mtm + modify (IntMap.insert i (mtm, Just argmtm)) + return argmtm -- | Turn a 'MonTerm' of type @CompMT(tp)@ to a term of argument type @MT(tp)@ -- by inserting a monadic bind if the 'MonTerm' is computational @@ -799,7 +821,15 @@ monadifyTypeM tp = -- | Monadify a term to a monadified term of argument type monadifyArg :: Maybe MonType -> Term -> MonadifyM ArgMonTerm -monadifyArg mtp t = monadifyTerm mtp t >>= argifyMonTerm +{- +monadifyArg _ t + | trace ("Monadifying term of argument type: " ++ showTerm t) False + = undefined +-} +monadifyArg mtp t@(STApp { stAppIndex = ix }) = + memoArgMonTerm ix $ monadifyTerm' mtp t +monadifyArg mtp t = + monadifyTerm' mtp t >>= argifyMonTerm -- | Monadify a term to argument type and convert back to a term monadifyArgTerm :: Maybe MonType -> Term -> MonadifyM OpenTerm @@ -813,7 +843,7 @@ monadifyTerm _ t = undefined -} monadifyTerm mtp t@(STApp { stAppIndex = ix }) = - memoizingM ix $ monadifyTerm' mtp t + memoMonTerm ix $ monadifyTerm' mtp t monadifyTerm mtp t = monadifyTerm' mtp t diff --git a/examples/mr_solver/mr_solver_unit_tests.saw b/examples/mr_solver/mr_solver_unit_tests.saw index 5366cdbab2..704ca9807c 100644 --- a/examples/mr_solver/mr_solver_unit_tests.saw +++ b/examples/mr_solver/mr_solver_unit_tests.saw @@ -11,7 +11,7 @@ let run_test name test expected = do { if expected then print (str_concat "Test: " name) else print (str_concat (str_concat "Test: " name) " (expecting failure)"); actual <- test; - if eq_bool actual expected then print "Success\n" else + if eq_bool actual expected then print "Test passed\n" else do { print "Test failed\n"; exit 1; }; }; // The constant 0 function const0 x = 0 @@ -21,19 +21,19 @@ const0 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 0)"; const1 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 1)"; // const0 <= const0 -run_test "mr_solver const0 const0" (mr_solver const0 const0) true; +run_test "const0 |= const0" (mr_solver_query const0 const0) true; // The function test_fun0 from the prelude = const0 test_fun0 <- parse_core "test_fun0"; -run_test "mr_solver const0 test_fun0" (mr_solver const0 test_fun0) true; +run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true; // not const0 <= const1 -run_test "mr_solver const0 const1" (mr_solver const0 const1) false; +run_test "const0 |= const1" (mr_solver_query const0 const1) false; // The function test_fun1 from the prelude = const1 test_fun1 <- parse_core "test_fun1"; -run_test "mr_solver const1 test_fun1" (mr_solver const1 test_fun1) true; -run_test "mr_solver const0 test_fun1" (mr_solver const0 test_fun1) false; +run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true; +run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false; // ifxEq0 x = If x == 0 then x else 0; should be equal to 0 ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \ @@ -42,10 +42,10 @@ ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \ \ (returnM (Vec 64 Bool) (bvNat 64 0))"; // ifxEq0 <= const0 -run_test "mr_solver ifxEq0 const0" (mr_solver ifxEq0 const0) true; +run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true; // not ifxEq0 <= const1 -run_test "mr_solver ifxEq0 const1" (mr_solver ifxEq0 const1) false; +run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false; // noErrors1 x = exists x. returnM x noErrors1 <- parse_core "\\ (x:Vec 64 Bool) -> \ @@ -53,10 +53,10 @@ noErrors1 <- parse_core "\\ (x:Vec 64 Bool) -> \ \ (\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool) x)"; // const0 <= noErrors -run_test "mr_solver noErrors1 noErrors1" (mr_solver noErrors1 noErrors1) true; +run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true; // const1 <= noErrors -run_test "mr_solver const1 noErrors1" (mr_solver const1 noErrors1) true; +run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true; // noErrorsRec1 x = orM (existsM x. returnM x) (noErrorsRec1 x) // Intuitively, this specifies functions that either return a value or loop @@ -74,4 +74,4 @@ loop1 <- parse_core \ (\\ (f: Vec 64 Bool -> CompM (Vec 64 Bool)) (x:Vec 64 Bool) -> f x)"; // loop1 <= noErrorsRec1 -run_test "mr_solver loop1 noErrorsRec1" (mr_solver loop1 noErrorsRec1) true; +run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true; diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw index 386c4f095a..5dfbb1fa9b 100644 --- a/heapster-saw/examples/arrays_mr_solver.saw +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -1,31 +1,15 @@ include "arrays.saw"; -let eq_bool b1 b2 = - if b1 then - if b2 then true else false - else - if b2 then false else true; - -let fail = do { print "Test failed"; exit 1; }; -let run_test name test expected = - do { if expected then print (str_concat "Test: " name) else - print (str_concat (str_concat "Test: " name) " (expecting failure)"); - actual <- test; - if eq_bool actual expected then print "Success\n" else - do { print "Test failed\n"; exit 1; }; }; - // Test that contains0 |= contains0 contains0 <- parse_core_mod "arrays" "contains0"; -// run_test "contains0 |= contains0" (mr_solver contains0 contains0) true; +mr_solver_test contains0 contains0; noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0"; -run_test "contains0 |= noErrorsContains0" - (mr_solver_debug 0 contains0 noErrorsContains0) true; +mr_solver_prove contains0 noErrorsContains0; include "specPrims.saw"; import "arrays.cry"; zero_array <- parse_core_mod "arrays" "zero_array"; -run_test "zero_array |= zero_array_spec" -// (mr_solver_debug 0 zero_array {{ zero_array_loop_spec }}) true; - (mr_solver_debug 0 zero_array {{ zero_array_spec }}) true; +// mr_solver_prove zero_array {{ zero_array_loop_spec }}; +mr_solver_prove zero_array {{ zero_array_spec }}; diff --git a/heapster-saw/examples/exp_explosion_mr_solver.saw b/heapster-saw/examples/exp_explosion_mr_solver.saw index 03c97256c2..2bd71bb927 100644 --- a/heapster-saw/examples/exp_explosion_mr_solver.saw +++ b/heapster-saw/examples/exp_explosion_mr_solver.saw @@ -1,23 +1,7 @@ include "exp_explosion.saw"; -let eq_bool b1 b2 = - if b1 then - if b2 then true else false - else - if b2 then false else true; - -let fail = do { print "Test failed"; exit 1; }; -let run_test name test expected = - do { if expected then print (str_concat "Test: " name) else - print (str_concat (str_concat "Test: " name) " (expecting failure)"); - actual <- test; - if eq_bool actual expected then print "Success\n" else - do { print "Test failed\n"; exit 1; }; }; - - - import "exp_explosion.cry"; monadify_term {{ op }}; exp_explosion <- parse_core_mod "exp_explosion" "exp_explosion"; -run_test "exp_explosion |= exp_explosion_spec" (mr_solver exp_explosion {{ exp_explosion_spec }}) true; +mr_solver_prove exp_explosion {{ exp_explosion_spec }}; diff --git a/heapster-saw/examples/linked_list_mr_solver.saw b/heapster-saw/examples/linked_list_mr_solver.saw index 2bf75117d9..28cc8093bb 100644 --- a/heapster-saw/examples/linked_list_mr_solver.saw +++ b/heapster-saw/examples/linked_list_mr_solver.saw @@ -1,24 +1,5 @@ include "linked_list.saw"; -/*** - *** Testing infrastructure - ***/ - -let eq_bool b1 b2 = - if b1 then - if b2 then true else false - else - if b2 then false else true; - -let fail = do { print "Test failed"; exit 1; }; -let run_test name test expected = - do { if expected then print (str_concat "Test: " name) else - print (str_concat (str_concat "Test: " name) " (expecting failure)"); - actual <- test; - if eq_bool actual expected then print "Success\n" else - do { print "Test failed\n"; exit 1; }; }; - - /*** *** Setup Cryptol environment ***/ @@ -45,15 +26,13 @@ heapster_typecheck_fun env "is_head" "(). arg0:int64<>, arg1:List,always,R> -o \ \ arg0:true, arg1:true, ret:int64<>"; -/* is_head <- parse_core_mod "linked_list" "is_head"; -run_test "is_head |= is_head" (mr_solver is_head is_head) true; -*/ +mr_solver_test is_head is_head; is_elem <- parse_core_mod "linked_list" "is_elem"; -// run_test "is_elem |= is_elem" (mr_solver_debug 0 is_elem is_elem) true; -/* +mr_solver_test is_elem is_elem; + is_elem_noErrorsSpec <- parse_core "\\ (x:Vec 64 Bool) (y:List (Vec 64 Bool)) -> \ \ fixM (Vec 64 Bool * List (Vec 64 Bool)) \ @@ -63,10 +42,9 @@ is_elem_noErrorsSpec <- parse_core \ orM (Vec 64 Bool) \ \ (existsM (Vec 64 Bool) (Vec 64 Bool) (returnM (Vec 64 Bool))) \ \ (rec x)) (x, y)"; -run_test "is_elem |= noErrorsSpec" (mr_solver is_elem is_elem_noErrorsSpec) true; -*/ +mr_solver_test is_elem is_elem_noErrorsSpec; -run_test "is_elem |= is_elem_spec" (mr_solver is_elem {{ is_elem_spec }}) true; +mr_solver_prove is_elem {{ is_elem_spec }}; monadify_term {{ Right }}; @@ -75,5 +53,4 @@ monadify_term {{ nil }}; monadify_term {{ cons }}; sorted_insert_no_malloc <- parse_core_mod "linked_list" "sorted_insert_no_malloc"; -run_test "sorted_insert_no_malloc |= sorted_insert_spec" - (mr_solver sorted_insert_no_malloc {{ sorted_insert_spec }}) true; +mr_solver_prove sorted_insert_no_malloc {{ sorted_insert_spec }}; diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 372a3f0731..928e7ab40f 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -86,10 +86,10 @@ processBlock <- parse_core_mod "SHA512" "processBlock"; processBlocks <- parse_core_mod "SHA512" "processBlocks"; // Test that every function refines itself -// run_test "processBlocks |= processBlocks" (mr_solver processBlocks processBlocks) true; -// run_test "processBlock |= processBlock" (mr_solver processBlock processBlock) true; -// run_test "round_16_80 |= round_16_80" (mr_solver round_16_80 round_16_80) true; -// run_test "round_00_15 |= round_00_15" (mr_solver round_00_15 round_00_15) true; +// mr_solver_test processBlocks processBlocks; +// mr_solver_test processBlock processBlock; +// mr_solver_test round_16_80 round_16_80; +// mr_solver_test round_00_15 round_00_15; import "sha512.cry"; // FIXME: Why aren't we monadifying these automatically when they're used? @@ -105,5 +105,5 @@ monadify_term {{ Maj }}; // "round_16_80 |= round_16_80_spec"? monadify_term {{ round_00_15_spec }}; -run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true; -run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true; +mr_solver_prove round_00_15 {{ round_00_15_spec }}; +mr_solver_prove round_16_80 {{ round_16_80_spec }}; diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 5d435c067c..d6404c0f35 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -71,6 +71,7 @@ import Verifier.SAW.FiniteValue import Verifier.SAW.SATQuery import Verifier.SAW.SCTypeCheck hiding (TypedTerm) import qualified Verifier.SAW.SCTypeCheck as TC (TypedTerm(..)) +import Verifier.SAW.Recognizer import Verifier.SAW.SharedTerm import Verifier.SAW.TypedTerm import qualified Verifier.SAW.Simulator.Concrete as Concrete @@ -79,6 +80,7 @@ import Verifier.SAW.Rewriter import Verifier.SAW.Testing.Random (prepareSATQuery, runManyTests) import Verifier.SAW.TypedAST import qualified Verifier.SAW.Simulator.TermModel as TM +import Verifier.SAW.Term.Pretty (SawDoc, renderSawDoc) import SAWScript.Position @@ -1654,18 +1656,100 @@ ensureMonadicTerm sc t False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t --- | Run Mr Solver with the given debug level to prove that the first term --- refines the second -mrSolver :: SharedContext -> Int -> TypedTerm -> TypedTerm -> TopLevel Bool -mrSolver sc dlvl t1 t2 = - do rw <- get - m1 <- ttTerm <$> ensureMonadicTerm sc t1 - m2 <- ttTerm <$> ensureMonadicTerm sc t2 - let env = rwMRSolverEnv rw - res <- liftIO $ Prover.askMRSolver sc dlvl env Nothing m1 m2 +-- | A wrapper for 'Prover.askMRSolver' from @MRSolver.hs@ which if the first +-- argument is @Just str@, prints out @str@ followed by an abridged version +-- of the refinement being asked +askMRSolver :: Maybe SawDoc -> SharedContext -> TypedTerm -> TypedTerm -> + TopLevel (NominalDiffTime, + Either Prover.MRFailure Prover.MRSolverResult) +askMRSolver printStr sc t1 t2 = + do env <- rwMRSolverEnv <$> get + m1 <- collapseEta <$> ttTerm <$> ensureMonadicTerm sc t1 + m2 <- collapseEta <$> ttTerm <$> ensureMonadicTerm sc t2 + case printStr of + Nothing -> return () + Just str -> printOutLnTop Info $ renderSawDoc defaultPPOpts $ + "[MRSolver] " <> str <> ": " <> ppTmHead m1 <> + " |= " <> ppTmHead m2 + time1 <- liftIO getCurrentTime + res <- io $ Prover.askMRSolver sc env Nothing m1 m2 + time2 <- liftIO getCurrentTime + return (diffUTCTime time2 time1, res) + where -- Turn a term of the form @\x1 ... xn -> f x1 ... xn@ into @f@ + collapseEta :: Term -> Term + collapseEta (asLambdaList -> (lamVars, + asApplyAll -> (t@(smallestFreeVar -> Nothing), + mapM asLocalVar -> Just argVars))) + | argVars == [(length lamVars - 1), (length lamVars - 2) .. 0] = t + collapseEta t = t + -- Pretty-print the name of the top-level function call, followed by + -- "..." if it is given any arguments, or just "..." if there is no + -- top-level call + ppTmHead :: Term -> SawDoc + ppTmHead (asLambdaList -> (_, + asApplyAll -> (t@( + Prover.asProjAll -> ( + Monadify.asTypedGlobalDef -> Just _, _)), args))) = + ppTerm defaultPPOpts t <> if length args > 0 then " ..." else "" + ppTmHead _ = "..." + +-- | Run Mr Solver to prove that the first term refines the second, adding +-- any relevant 'Prover.FunAssump's to the 'Prover.MREnv' if the first argument +-- is true and this can be done, or printing an error message and exiting if it +-- cannot. +mrSolverProve :: Bool -> SharedContext -> TypedTerm -> TypedTerm -> TopLevel () +mrSolverProve addToEnv sc t1 t2 = + do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get + let printStr = if addToEnv then "Proving" else "Testing" + (diff, res) <- askMRSolver (Just printStr) sc t1 t2 + case res of + Left err | dlvl == 0 -> + io (putStrLn $ Prover.showMRFailure err) >> + printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + io (putStrLn $ Prover.showMRFailureNoCtx err) >> + printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Right (Just (fnm, fassump)) | addToEnv -> + let assump_str = case Prover.fassumpRHS fassump of + Prover.OpaqueFunAssump _ _ -> "an opaque" + Prover.RewriteFunAssump _ -> "a rewrite" in + printOutLnTop Info ( + printf "[MRSolver] Success in %s, added as %s assumption" + (show diff) (assump_str :: String)) >> + modify (\rw -> rw { rwMRSolverEnv = + Prover.mrEnvAddFunAssump fnm fassump (rwMRSolverEnv rw) }) + _ -> + printOutLnTop Info $ printf "[MRSolver] Success in %s" (show diff) + +-- | Run Mr Solver to prove that the first term refines the second, returning +-- true iff this can be done. This function will not modify the 'Prover.MREnv'. +mrSolverQuery :: SharedContext -> TypedTerm -> TypedTerm -> TopLevel Bool +mrSolverQuery sc t1 t2 = + do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get + (diff, res) <- askMRSolver (Just "Querying") sc t1 t2 case res of - Left err -> io (putStrLn $ Prover.showMRFailure err) >> return False - Right env' -> put (rw { rwMRSolverEnv = env' }) >> return True + Left _ | dlvl == 0 -> + printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> + return False + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + io (putStrLn $ Prover.showMRFailureNoCtx err) >> + printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> + return False + Right _ -> + printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >> + return True + +-- | Set the debug level of the 'Prover.MREnv' +mrSolverSetDebug :: Int -> TopLevel () +mrSolverSetDebug dlvl = + modify (\rw -> rw { rwMRSolverEnv = + Prover.mrEnvSetDebugLevel dlvl (rwMRSolverEnv rw) }) setMonadification :: SharedContext -> String -> String -> TopLevel () setMonadification sc cry_str saw_str = diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 06ffcdfc35..e5743e2b00 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3259,16 +3259,41 @@ primitives = Map.fromList --------------------------------------------------------------------- - , prim "mr_solver" "Term -> Term -> TopLevel Bool" - (scVal (\sc -> mrSolver sc 0)) + , prim "mr_solver_prove" "Term -> Term -> TopLevel ()" + (scVal (mrSolverProve True)) Experimental [ "Call the monadic-recursive solver (that's MR. Solver to you)" - , " to ask if one monadic term refines another" ] + , " to prove that one monadic term refines another. If this can" + , " be done, this refinement will be used in future calls to" + , " Mr. Solver, and if it cannot, the script will exit. See also:" + , " mr_solver_test, mr_solver_query." ] - , prim "mr_solver_debug" "Int -> Term -> Term -> TopLevel Bool" - (scVal mrSolver) + , prim "mr_solver_test" "Term -> Term -> TopLevel ()" + (scVal (mrSolverProve False)) Experimental - [ "Call the monadic-recursive solver at the supplied debug level" ] + [ "Call the monadic-recursive solver (that's MR. Solver to you)" + , " to prove that one monadic term refines another. If this cannot" + , " be done, the script will exit. See also: mr_solver_prove," + , " mr_solver_query - unlike the former, this refinement will not" + , " be used in future calls to Mr. Solver." ] + + , prim "mr_solver_query" "Term -> Term -> TopLevel Bool" + (scVal mrSolverQuery) + Experimental + [ "Call the monadic-recursive solver (that's MR. Solver to you)" + , " to prove that one monadic term refines another, returning" + , " true iff this can be done. See also: mr_solver_prove," + , " mr_solver_test - unlike the former, this refinement will not" + , " be considered in future calls to Mr. Solver, and unlike both," + , " this command will never fail." ] + + , prim "mr_solver_set_debug_level" "Int -> TopLevel ()" + (pureVal mrSolverSetDebug) + Experimental + [ "Set the debug level for Mr. Solver; 0 = no debug output," + , " 1 = some debug output, 2 = all debug output" ] + + --------------------------------------------------------------------- , prim "monadify_term" "Term -> TopLevel Term" (scVal monadifyTypedTerm) diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index 759116dedf..b422cfd996 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -9,8 +9,11 @@ Portability : non-portable (language extensions) -} module SAWScript.Prover.MRSolver - (askMRSolver, MRFailure(..), showMRFailure, isCompFunType, - MREnv(..), emptyMREnv) where + (askMRSolver, MRSolverResult, + MRFailure(..), showMRFailure, showMRFailureNoCtx, + FunAssump(..), FunAssumpRHS(..), + MREnv(..), emptyMREnv, mrEnvAddFunAssump, mrEnvSetDebugLevel, + asProjAll, isCompFunType) where import SAWScript.Prover.MRSolver.Term import SAWScript.Prover.MRSolver.Monad diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index bb5a5b9148..dbb20fd7e7 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -90,6 +90,15 @@ data MRFailure pattern TermsNotEq :: Term -> Term -> MRFailure pattern TermsNotEq t1 t2 = TermsNotRel False t1 t2 +-- | Remove the context from a 'MRFailure', i.e. remove all applications of the +-- 'MRFailureLocalVar' and 'MRFailureCtx' constructors +mrFailureWithoutCtx :: MRFailure -> MRFailure +mrFailureWithoutCtx (MRFailureLocalVar _ err) = mrFailureWithoutCtx err +mrFailureWithoutCtx (MRFailureCtx _ err) = mrFailureWithoutCtx err +mrFailureWithoutCtx (MRFailureDisj err1 err2) = + MRFailureDisj (mrFailureWithoutCtx err1) (mrFailureWithoutCtx err2) +mrFailureWithoutCtx err = err + -- | Pretty-print an object prefixed with a 'String' that describes it ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a @@ -161,6 +170,11 @@ instance PrettyInCtx MRFailure where showMRFailure :: MRFailure -> String showMRFailure = showInCtx [] +-- | Render a 'MRFailure' to a 'String' without its context (see +-- 'mrFailureWithoutCtx') +showMRFailureNoCtx :: MRFailure -> String +showMRFailureNoCtx = showMRFailure . mrFailureWithoutCtx + ---------------------------------------------------------------------- -- * MR Monad @@ -278,9 +292,7 @@ data MRInfo = MRInfo { -- note that these have the current UVars free mriAssumptions :: Term, -- | The current set of 'DataTypeAssump's - mriDataTypeAssumps :: DataTypeAssumps, - -- | The debug level, which controls debug printing - mriDebugLevel :: Int + mriDataTypeAssumps :: DataTypeAssumps } -- | State maintained by MR. Solver @@ -338,9 +350,9 @@ mrAssumptions = mriAssumptions <$> ask mrDataTypeAssumps :: MRM DataTypeAssumps mrDataTypeAssumps = mriDataTypeAssumps <$> ask --- | Get the current value of 'mriDebugLevel' +-- | Get the current debug level mrDebugLevel :: MRM Int -mrDebugLevel = mriDebugLevel <$> ask +mrDebugLevel = mreDebugLevel <$> mriEnv <$> ask -- | Get the current value of 'mriEnv' mrEnv :: MRM MREnv @@ -351,12 +363,12 @@ mrVars :: MRM MRVarMap mrVars = mrsVars <$> get -- | Run an 'MRM' computation and return a result or an error -runMRM :: SharedContext -> Maybe Integer -> Int -> MREnv -> +runMRM :: SharedContext -> Maybe Integer -> MREnv -> MRM a -> IO (Either MRFailure a) -runMRM sc timeout debug env m = +runMRM sc timeout env m = do true_tm <- scBool sc True let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, - mriDebugLevel = debug, mriEnv = env, + mriEnv = env, mriUVars = [], mriCoIndHyps = Map.empty, mriAssumptions = true_tm, mriDataTypeAssumps = HashMap.empty } @@ -895,20 +907,12 @@ withFunAssump fname args rhs m = fname args CompFunReturn) "|=" rhs ctx <- mrUVarCtx assumps <- mrFunAssumps - let assumps' = Map.insert fname (FunAssump ctx args rhs) assumps + let assump = FunAssump ctx args (RewriteFunAssump rhs) + let assumps' = Map.insert fname assump assumps local (\info -> let env' = (mriEnv info) { mreFunAssumps = assumps' } in info { mriEnv = env' }) m --- | Generate fresh evars for the context of a 'FunAssump' and substitute them --- into its arguments and right-hand side -instantiateFunAssump :: FunAssump -> MRM ([Term], NormComp) -instantiateFunAssump fassump = - do evars <- mrFreshEVars $ fassumpCtx fassump - args <- substTermLike 0 evars $ fassumpArgs fassump - rhs <- substTermLike 0 evars $ fassumpRHS fassump - return (args, rhs) - -- | Get the invariant hint associated with a function name, by unfolding the -- name and checking if its body has the form -- diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 74b65cae1b..002317237c 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -799,12 +799,30 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = matchCoIndHyp hyp args1 args2 >> mrRefinesFun k1 k2 - -- If we have an assumption that f1 args' refines some rhs, then prove that - -- args1 = args' and then that rhs refines m2 - (_, Just fassump) -> - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrAssertProveEq assump_args args1 - m1' <- normBind assump_rhs k1 + -- If we have an opaque FunAssump that f1 args1' refines f2 args2', then + -- prove that args1 = args1', args2 = args2', and then that k1 refines k2 + (_, Just (FunAssump ctx args1' (OpaqueFunAssump f2' args2'))) | f2 == f2' -> + do evars <- mrFreshEVars ctx + (args1'', args2'') <- substTermLike 0 evars (args1', args2') + zipWithM_ mrAssertProveEq args1'' args1 + zipWithM_ mrAssertProveEq args2'' args2 + mrRefinesFun k1 k2 + + -- If we have an opaque FunAssump that f1 refines some f /= f2, and f2 + -- unfolds and is not recursive in itself, unfold f2 and recurse + (_, Just (FunAssump _ _ (OpaqueFunAssump _ _))) + | Just (f2_body, False) <- maybe_f2_body -> + normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + + -- If we have a rewrite FunAssump, or we have an opaque FunAssump that + -- f1 args1' refines some f args where f /= f2 and f2 does not match the + -- case above, treat either case like we have a rewrite FunAssump and prove + -- that args1 = args1' and then that f args refines m2 + (_, Just (FunAssump ctx args1' (funAssumpRHSAsNormComp -> rhs))) -> + do evars <- mrFreshEVars ctx + (args1'', rhs') <- substTermLike 0 evars (args1', rhs) + zipWithM_ mrAssertProveEq args1'' args1 + m1' <- normBind rhs' k1 mrRefines m1' m2 -- If f1 unfolds and is not recursive in itself, unfold it and recurse @@ -839,10 +857,11 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- If we have an assumption that f1 args' refines some rhs, then prove that -- args1 = args' and then that rhs refines m2 - Just fassump -> - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrAssertProveEq assump_args args1 - m1' <- normBind assump_rhs k1 + Just (FunAssump ctx args1' (funAssumpRHSAsNormComp -> rhs)) -> + do evars <- mrFreshEVars ctx + (args1'', rhs') <- substTermLike 0 evars (args1', rhs) + zipWithM_ mrAssertProveEq args1'' args1 + m1' <- normBind rhs' k1 mrRefines m1' m2 -- Otherwise, see if we can unfold f1 @@ -927,9 +946,14 @@ mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" -- * External Entrypoints ---------------------------------------------------------------------- +-- | The result of a successful call to Mr. Solver: either a 'FunAssump' to +-- (optionally) add to the 'MREnv', or 'Nothing' if the left-hand-side was not +-- a function name +type MRSolverResult = Maybe (FunName, FunAssump) + -- | The main loop of 'askMRSolver'. The first argument is an accumulator of -- variables to introduce, innermost first. -askMRSolverH :: [Term] -> Term -> Term -> Term -> Term -> MRM MREnv +askMRSolverH :: [Term] -> Term -> Term -> Term -> Term -> MRM MRSolverResult -- If we need to introduce a bitvector on one side and a Num on the other, -- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that @@ -1016,15 +1040,22 @@ askMRSolverH _ (asCompM -> Just _) t1 (asCompM -> Just _) t2 = do m1 <- normCompTerm t1 m2 <- normCompTerm t2 mrRefines m1 m2 - -- If t1 is a named function, add forall xs. f1 xs |= m2 to the env - case asApplyAll t1 of - ((asGlobalFunName -> Just f1), args) -> + case (m1, m2) of + -- If t1 and t2 are both named functions, our result is the opaque + -- FunAssump that forall xs. f1 xs |= f2 xs' + (FunBind f1 args1 CompFunReturn, FunBind f2 args2 CompFunReturn) -> + mrUVarCtx >>= \uvar_ctx -> + return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, + fassumpArgs = args1, + fassumpRHS = OpaqueFunAssump f2 args2 }) + -- If just t1 is a named function, our result is the rewrite FunAssump + -- that forall xs. f1 xs |= m2 + (FunBind f1 args1 CompFunReturn, _) -> mrUVarCtx >>= \uvar_ctx -> - let fassump = FunAssump { fassumpCtx = uvar_ctx, - fassumpArgs = args, - fassumpRHS = m2 } in - mrEnvAddFunAssump f1 fassump <$> mrEnv - _ -> mrEnv + return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, + fassumpArgs = args1, + fassumpRHS = RewriteFunAssump m2 }) + _ -> return Nothing -- Error if we don't have CompM at the end askMRSolverH _ (asCompM -> Just _) _ tp2 _ = @@ -1038,14 +1069,13 @@ askMRSolverH _ tp1 _ _ _ = -- environment. askMRSolver :: SharedContext -> - Int {- ^ The debug level -} -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> - Term -> Term -> IO (Either MRFailure MREnv) + Term -> Term -> IO (Either MRFailure MRSolverResult) -askMRSolver sc dlvl env timeout t1 t2 = +askMRSolver sc env timeout t1 t2 = do tp1 <- scTypeOf sc t1 >>= scWhnf sc tp2 <- scTypeOf sc t2 >>= scWhnf sc - runMRM sc timeout dlvl env $ + runMRM sc timeout env $ mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 >> askMRSolverH [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 10f958f67b..ef093df317 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -228,6 +228,16 @@ asNonBVVecVectorType t = asVectorType t -- * Mr Solver Environments ---------------------------------------------------------------------- +-- | The right-hand-side of a 'FunAssump': either a 'FunName' and arguments, if +-- it is an opaque 'FunAsump', or a 'NormComp', if it is a rewrite 'FunAssump' +data FunAssumpRHS = OpaqueFunAssump FunName [Term] + | RewriteFunAssump NormComp + +-- | Convert a 'FunAssumpRHS' to a 'NormComp' +funAssumpRHSAsNormComp :: FunAssumpRHS -> NormComp +funAssumpRHSAsNormComp (OpaqueFunAssump f args) = FunBind f args CompFunReturn +funAssumpRHSAsNormComp (RewriteFunAssump rhs) = rhs + -- | An assumption that a named function refines some specification. This has -- the form -- @@ -244,7 +254,7 @@ data FunAssump = FunAssump { -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars fassumpArgs :: [Term], -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars - fassumpRHS :: NormComp + fassumpRHS :: FunAssumpRHS } -- | A map from function names to function refinement assumptions over that @@ -257,18 +267,24 @@ type FunAssumps = Map FunName FunAssump data MREnv = MREnv { -- | The set of function refinements to be assumed by to Mr. Solver (which -- have hopefully been proved previously...) - mreFunAssumps :: FunAssumps - } + mreFunAssumps :: FunAssumps, + -- | The debug level, which controls debug printing + mreDebugLevel :: Int +} -- | The empty 'MREnv' emptyMREnv :: MREnv -emptyMREnv = MREnv { mreFunAssumps = Map.empty } +emptyMREnv = MREnv { mreFunAssumps = Map.empty, mreDebugLevel = 0 } -- | Add a 'FunAssump' to a Mr Solver environment mrEnvAddFunAssump :: FunName -> FunAssump -> MREnv -> MREnv mrEnvAddFunAssump f fassump env = env { mreFunAssumps = Map.insert f fassump (mreFunAssumps env) } +-- | Set the debug level of a Mr Solver environment +mrEnvSetDebugLevel :: Int -> MREnv -> MREnv +mrEnvSetDebugLevel dlvl env = env { mreDebugLevel = dlvl } + ---------------------------------------------------------------------- -- * Utility Functions for Transforming 'Term's