From 26d769066393b436be4cd7efbf5ea0eb46e5cfbb Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 16:11:23 -0700 Subject: [PATCH 01/12] fail more gracefully when `natSize` is undefined, improve panic traces --- saw-core/src/Verifier/SAW/Simulator/Prims.hs | 91 +++++++++++--------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/saw-core/src/Verifier/SAW/Simulator/Prims.hs b/saw-core/src/Verifier/SAW/Simulator/Prims.hs index b9426d85cf..e6b07aaf36 100644 --- a/saw-core/src/Verifier/SAW/Simulator/Prims.hs +++ b/saw-core/src/Verifier/SAW/Simulator/Prims.hs @@ -63,7 +63,8 @@ import Control.Monad.Fix (MonadFix(mfix)) import Control.Monad.Trans import Control.Monad.Trans.Maybe import Control.Monad.Trans.Except -import Data.Bifunctor +import Data.Maybe (fromMaybe) +import Data.Bitraversable import Data.Bits import Data.Map (Map) import qualified Data.Map as Map @@ -120,17 +121,6 @@ natFun = PrimFilterFun "expected Nat" r r (VCtorApp (primName -> "Prelude.Succ") [] [x]) = succ <$> (r =<< lift (force x)) r _ = mzero --- | A primitive that requires a natural argument which may or may not be --- concrete - like 'natFun' but gives the 'Value' instead of failing if the --- argument is not concrete -maybeNatFun :: VMonad l => (Either (Value l) Natural -> Prim l) -> Prim l -maybeNatFun = PrimFilterFun "expected Nat" r - where r (VNat n) = pure (Right n) - r (VCtorApp (primName -> "Prelude.Zero") [] []) = pure (Right 0) - r v@(VCtorApp (primName -> "Prelude.Succ") [] [x]) = - bimap (const v) succ <$> (r =<< lift (force x)) - r v = pure (Left v) - -- | A primitive that requires an integer argument intFun :: VMonad l => (VInt l -> Prim l) -> Prim l intFun = PrimFilterFun "expected Integer" r @@ -333,7 +323,7 @@ constMap bp = Map.fromList , ("Prelude.maxNat", maxNatOp bp) , ("Prelude.divModNat", divModNatOp bp) , ("Prelude.expNat", expNatOp) - , ("Prelude.widthNat", widthNatOp bp) + , ("Prelude.widthNat", widthNatOp) , ("Prelude.natCase", natCaseOp) , ("Prelude.equalNat", equalNatOp bp) , ("Prelude.ltNat", ltNatOp bp) @@ -416,23 +406,25 @@ toBool x = panic $ unwords ["Verifier.SAW.Simulator.toBool", show x] type Pack l = Vector (VBool l) -> MWord l type Unpack l = VWord l -> EvalM l (Vector (VBool l)) -toWord :: (VMonad l, Show (Extra l)) => Pack l -> Value l -> MWord l +toWord :: (HasCallStack, VMonad l, Show (Extra l)) + => Pack l -> Value l -> MWord l toWord _ (VWord w) = return w toWord pack (VVector vv) = pack =<< V.mapM (liftM toBool . force) vv toWord _ x = panic $ unwords ["Verifier.SAW.Simulator.toWord", show x] -toWordPred :: (VMonad l, Show (Extra l)) => Value l -> VWord l -> MBool l +toWordPred :: (HasCallStack, VMonad l, Show (Extra l)) + => Value l -> VWord l -> MBool l toWordPred (VFun _ f) = fmap toBool . f . ready . VWord toWordPred x = panic $ unwords ["Verifier.SAW.Simulator.toWordPred", show x] -toBits :: (VMonad l, Show (Extra l)) => Unpack l -> Value l -> - EvalM l (Vector (VBool l)) +toBits :: (HasCallStack, VMonad l, Show (Extra l)) + => Unpack l -> Value l -> EvalM l (Vector (VBool l)) toBits unpack (VWord w) = unpack w toBits _ (VVector v) = V.mapM (liftM toBool . force) v toBits _ x = panic $ unwords ["Verifier.SAW.Simulator.toBits", show x] -toVector :: (VMonad l, Show (Extra l)) => Unpack l - -> Value l -> ExceptT Text (EvalM l) (Vector (Thunk l)) +toVector :: (HasCallStack, VMonad l, Show (Extra l)) + => Unpack l -> Value l -> ExceptT Text (EvalM l) (Vector (Thunk l)) toVector _ (VVector v) = return v toVector unpack (VWord w) = lift (liftM (fmap (ready . VBool)) (unpack w)) toVector _ x = throwE $ "Verifier.SAW.Simulator.toVector " <> Text.pack (show x) @@ -443,7 +435,7 @@ vecIdx err v n = Just a -> a Nothing -> err -toArray :: (VMonad l, Show (Extra l)) => Value l -> MArray l +toArray :: (HasCallStack, VMonad l, Show (Extra l)) => Value l -> MArray l toArray (VArray f) = return f toArray x = panic $ unwords ["Verifier.SAW.Simulator.toArray", show x] @@ -534,13 +526,32 @@ coerceOp = -- | Return the number of bits necessary to represent the given value, -- which should be a value of type Nat. -natSize :: HasCallStack => BasePrims l -> Value l -> Natural -natSize _bp val = +natSizeMaybe :: HasCallStack => Value l -> Maybe Natural +natSizeMaybe val = case val of - VNat n -> widthNat n - VBVToNat n _ -> fromIntegral n -- TODO, remove this fromIntegral + VNat n -> Just $ widthNat n + VBVToNat n _ -> Just $ fromIntegral n -- TODO, remove this fromIntegral VIntToNat _ -> panic "natSize: symbolic integer (TODO)" - _ -> panic "natSize: expected Nat" + _ -> Nothing + +-- | Return the number of bits necessary to represent the given value, +-- which should be a value of type Nat, calling 'panic' if this cannot be done. +natSize :: (HasCallStack, Show (Extra l)) => Value l -> Natural +natSize val = fromMaybe (panic $ "natSize: expected Nat, got: " ++ show val) + (natSizeMaybe val) + +-- | A primitive that requires a natural argument, returning its value as a +-- 'Natural' if the argument is concrete, or a pair of a size in bits and a +-- 'Value', if 'natSizeMaybe' returns 'Just' +natSizeFun :: (HasCallStack, VMonad l) => + (Either (Natural, Value l) Natural -> Prim l) -> Prim l +natSizeFun = PrimFilterFun "expected Nat" r + where r (VNat n) = pure (Right n) + r (VCtorApp (primName -> "Prelude.Zero") [] []) = pure (Right 0) + r v@(VCtorApp (primName -> "Prelude.Succ") [] [x]) = + lift (force x) >>= r >>= bimapM (const (szPr v)) (pure . succ) + r v = Left <$> szPr v + szPr v = maybe mzero (pure . (,v)) (natSizeMaybe v) -- | Convert the given value (which should be of type Nat) to a word -- of the given bit-width. The bit-width must be at least as large as @@ -557,12 +568,12 @@ natToWord bp w val = VBVToNat xsize v -> do x <- toWord (bpPack bp) v case compare xsize (fromIntegral w) of - GT -> panic "natToWord: not enough bits" + GT -> panic $ "natToWord: not enough bits for: " ++ show val EQ -> return x LT -> -- zero-extend x to width w do pad <- bpBvLit bp (fromIntegral w - xsize) 0 bpBvJoin bp pad x - _ -> panic "natToWord: expected Nat" + _ -> panic $ "natToWord: expected Nat, got: " ++ show val -- | A primitive which is a unary operation on a natural argument. -- The second argument gives how to modify the size in bits of this operation's @@ -574,11 +585,11 @@ unaryNatOp :: (HasCallStack, VMonad l, Show (Extra l)) => BasePrims l -> (Natural -> Natural) -> (Natural -> MValue l) -> (Int -> VWord l -> MValue l) -> Prim l -unaryNatOp bp fw fn fv = maybeNatFun $ \case +unaryNatOp bp fw fn fv = natSizeFun $ \case Right n -> Prim (fn n) - Left v -> Prim $ do let w = fw (natSize bp v) - x <- natToWord bp w v - fv (fromIntegral w) x + Left (w1,v) -> Prim $ do let w = fw w1 + x <- natToWord bp w v + fv (fromIntegral w) x -- | A primitive which is a unary operation on a natural argument and which -- returns a natural. @@ -607,12 +618,12 @@ binNatOp :: (HasCallStack, VMonad l, Show (Extra l)) => BasePrims l -> (Natural -> Natural -> Natural) -> (Natural -> Natural -> MValue l) -> (Int -> VWord l -> VWord l -> MValue l) -> Prim l -binNatOp bp fw fn fv = maybeNatFun $ \m -> maybeNatFun $ \n -> go m n +binNatOp bp fw fn fv = natSizeFun (natSizeFun . go) where go (Right m) (Right n) = Prim (fn m n) - go (Right m) (Left v2) = go (Left (VNat m)) (Left v2) - go (Left v1) (Right n) = go (Left v1) (Left (VNat n)) - go (Left v1) (Left v2) = Prim $ - do let w = fw (natSize bp v1) (natSize bp v2) + go (Right m) (Left pr) = go (Left (widthNat m, VNat m)) (Left pr) + go (Left pr) (Right n) = go (Left pr) (Left (widthNat n, VNat n)) + go (Left (w1,v1)) (Left (w2,v2)) = Prim $ + do let w = fw w1 w2 x1 <- natToWord bp w v1 x2 <- natToWord bp w v2 fv (fromIntegral w) x1 x2 @@ -686,10 +697,10 @@ expNatOp = PrimValue (vNat (m ^ n)) -- widthNat :: Nat -> Nat; -widthNatOp :: (HasCallStack, VMonad l, Show (Extra l)) => BasePrims l -> Prim l -widthNatOp bp = maybeNatFun $ \case +widthNatOp :: (HasCallStack, VMonad l, Show (Extra l)) => Prim l +widthNatOp = natSizeFun $ \case Right n -> PrimValue (vNat (widthNat n)) - Left v -> PrimValue (vNat (natSize bp v)) + Left (w,_) -> PrimValue (vNat w) -- equalNat :: Nat -> Nat -> Bool; equalNatOp :: (HasCallStack, VMonad l, Show (Extra l)) => BasePrims l -> Prim l @@ -1336,7 +1347,7 @@ muxValue bp tp0 b = value tp0 nat :: Value l -> Value l -> MValue l nat v1 v2 = - do let w = max (natSize bp v1) (natSize bp v2) + do let w = max (natSize v1) (natSize v2) x1 <- natToWord bp w v1 x2 <- natToWord bp w v2 VBVToNat (fromIntegral w) . VWord <$> bpMuxWord bp b x1 x2 From 9fe96d3427af6b06312ff92d4a864621ddbfa5a5 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 16:11:52 -0700 Subject: [PATCH 02/12] add implementation of scSplitM, ecJoinM --- cryptol-saw-core/saw/CryptolM.sawcore | 59 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index 59462730db..e85580616a 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -409,14 +409,65 @@ primitive ecDropM : (m : Num) -> isFinite m -> (n : Num) -> (a : sort 0) -> mseq (tcAdd m n) a -> mseq n a; --- FIXME -primitive ecJoinM : (m n : Num) -> (a : sort 0) -> mseq m (mseq n a) -> mseq (tcMul m n) a; +ecJoinM = + Num_rec + (\ (m:Num) -> (n:Num) -> (a:isort 0) -> mseq m (mseq n a) -> + mseq (tcMul m n) a) + (\ (m:Nat) -> + finNumRec + (\ (n:Num) -> (a:isort 0) -> Vec m (mseq n a) -> + mseq (tcMul (TCNum m) n) a) + -- Case for (TCNum m, TCNum n) + (\ (n:Nat) -> \ (a:isort 0) -> join m n a)) + -- No case for (TCNum m, TCInf), shoudn't happen + (finNumRec + (\ (n:Num) -> (a:isort 0) -> Stream (CompM (mseq n a)) -> + mseq (tcMul TCInf n) a) + -- Case for (TCInf, TCNum n) + (\ (n:Nat) -> \ (a:isort 0) -> + natCase + (\ (n':Nat) -> Stream (CompM (Vec n' a)) -> + mseq (if0Nat Num n' (TCNum 0) TCInf) a) + (\ (s:Stream (CompM (Vec 0 a))) -> EmptyVec a) + (\ (n':Nat) -> \ (s:Stream (CompM (Vec (Succ n') a))) -> + MkStream (CompM a) (\ (i:Nat) -> + fmapM (Vec (Succ n') a) a + (\ (v:Vec (Succ n') a) -> at (Succ n') a v (modNat i (Succ n'))) + (streamGet (CompM (Vec (Succ n') a)) s (divNat i (Succ n'))) )) + n)); + -- No case for (TCInf, TCInf), shouldn't happen --- FIXME -primitive ecSplitM : (m n : Num) -> (a : sort 0) -> mseq (tcMul m n) a -> mseq m (mseq n a); +ecSplitM = + Num_rec + (\ (m:Num) -> (n:Num) -> (a:isort 0) -> mseq (tcMul m n) a -> + mseq m (mseq n a)) + (\ (m:Nat) -> + finNumRec + (\ (n:Num) -> (a:isort 0) -> mseq (tcMul (TCNum m) n) a -> + Vec m (mseq n a)) + -- Case for (TCNum m, TCNum n) + (\ (n:Nat) -> \ (a:isort 0) -> split m n a)) + -- No case for (TCNum m, TCInf), shouldn't happen + (finNumRec + (\ (n:Num) -> (a:isort 0) -> mseq (tcMul TCInf n) a -> + Stream (CompM (mseq n a))) + -- Case for (TCInf, TCNum n) + (\ (n:Nat) -> \ (a:isort 0) -> + natCase + (\ (n':Nat) -> + mseq (if0Nat Num n' (TCNum 0) TCInf) a -> + Stream (CompM (Vec n' a))) + (\ (xs : Vec 0 a) -> streamConst (CompM (Vec 0 a)) + (returnM (Vec 0 a) xs)) + (\ (n':Nat) (xs : Stream (CompM a)) -> + streamMap (Vec (Succ n') (CompM a)) (CompM (Vec (Succ n') a)) + (vecMapM (CompM a) a (Succ n') (id (CompM a))) + (streamSplit (CompM a) (Succ n') xs)) + n)); + -- No case for (TCInf, TCInf), shouldn't happen ecReverseM : (n : Num) -> isFinite n -> (a : sort 0) -> mseq n a -> mseq n a; ecReverseM = From af3e84178ee462d1e3f6fe95c9d08c44cb181cb2 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 16:16:28 -0700 Subject: [PATCH 03/12] refactor mrRefinesFunH to use types directly, use mrProveEq in BVVec het --- src/SAWScript/Prover/MRSolver/Monad.hs | 22 ++--- src/SAWScript/Prover/MRSolver/SMT.hs | 39 +++++++-- src/SAWScript/Prover/MRSolver/Solver.hs | 109 +++++++++++------------- src/SAWScript/Prover/MRSolver/Term.hs | 19 +++-- 4 files changed, 99 insertions(+), 90 deletions(-) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 4c16112756..9db935e10d 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -507,19 +507,6 @@ matchHet (asPairType -> Just (tpL1, tpR1)) Just $ HetPair (tpL1, tpR1) (tpL2, tpR2) matchHet _ _ = Nothing --- | Return true iff the given types are heterogeneously related -typesHetRelated :: Term -> Term -> MRM Bool -typesHetRelated tp1 tp2 = case matchHet tp1 tp2 of - Just (HetBVNum _) -> return True - Just (HetNumBV _) -> return True - Just (HetBVVecVec (n, len, a) (m, a')) -> mrBvToNat n len >>= \m' -> - (&&) <$> mrConvertible m m' <*> typesHetRelated a a' - Just (HetVecBVVec (m, a') (n, len, a)) -> mrBvToNat n len >>= \m' -> - (&&) <$> mrConvertible m m' <*> typesHetRelated a a' - Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> - (&&) <$> typesHetRelated tpL1 tpL2 <*> typesHetRelated tpR1 tpR2 - Nothing -> mrConvertible tp1 tp2 - ---------------------------------------------------------------------- -- * Functions for Building Terms @@ -597,13 +584,13 @@ mrCtorApp = liftSC2 scCtorApp mrGlobalTerm :: Ident -> MRM Term mrGlobalTerm = liftSC1 scGlobalDef --- | Like 'scBvNat', but if given a bitvector literal it is converted to a +-- | Like 'scBvConst', but if given a bitvector literal it is converted to a -- natural number literal mrBvToNat :: Term -> Term -> MRM Term mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, mapM asBool -> Just bits)) = liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits -mrBvToNat n len = liftSC2 scBvNat n len +mrBvToNat n len = liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in @@ -622,7 +609,7 @@ mrTypeOf :: Term -> MRM Term mrTypeOf t = -- NOTE: scTypeOf' wants the type context in the most recently bound var -- first, i.e., in the mrUVarCtxRev order - mrDebugPPPrefix 3 "mrTypeOf:" t >> + -- mrDebugPPPrefix 3 "mrTypeOf:" t >> mrUVarCtxRev >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t -- | Check if two 'Term's are convertible in the 'MRM' monad @@ -692,13 +679,14 @@ withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a withUVars [] f = f [] withUVars ctx f = do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars - let ctx_u = zip nms $ map (Type . snd) ctx + ctx_u <- zip nms <$> mapM (liftTermLike 0 (length ctx) . Type . snd) ctx assumps' <- mrAssumptions >>= liftTerm 0 (length ctx) dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (length ctx)) vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] local (\info -> info { mriUVars = reverse ctx_u ++ mriUVars info, mriAssumptions = assumps', mriDataTypeAssumps = dataTypeAssumps' }) $ + mrDebugPPPrefix 3 "withUVars:" ctx_u >> foldr (\nm m -> mapMRFailure (MRFailureLocalVar nm) m) (f vars) nms -- | Run a MR Solver in a top-level context, i.e., with no uvars or assumptions diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 29142dc90c..9bd6b2d61d 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -163,7 +163,7 @@ bvVecFromBVVecOrLit sc n n' len a (BVVecLit vs) = i_tp <- scBitvector sc n var0 <- scLocalVar sc 0 pf_tp <- scGlobalApply sc "Prelude.is_bvult" [n', var0, len] - f <- scLambdaList sc [("i", i_tp), ("pf", pf_tp)] body + f <- scLambdaList sc [("i", i_tp), ("pf", pf_tp)] body scGlobalApply sc "Prelude.genBVVec" [n', len, a, f] where mkBody :: Integer -> [Term] -> IO Term mkBody _ [] = error "bvVecFromBVVecOrLit: empty vector" @@ -310,14 +310,17 @@ mrProvableRaw prop_term = mrProvable :: Term -> MRM Bool mrProvable (asBool -> Just b) = return b mrProvable bool_tm = - do assumps <- mrAssumptions + do uvarCtx <- mrUVarCtx + debugPretty 3 $ "mrProvable uvars:" <> ppCtx uvarCtx + assumps <- mrAssumptions prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue prop_inst <- mrSubstEVars prop >>= instantiateUVarsM instUVar mrNormTerm prop_inst >>= mrProvableRaw where -- | Given a UVar name and type, generate a 'Term' to be passed to -- SMT, with special cases for BVVec and pair types instUVar :: LocalName -> Term -> MRM Term - instUVar nm tp = liftSC1 scWhnf tp >>= \case + instUVar nm tp = mrDebugPPPrefix 3 "instUVar" (nm, tp) >> + liftSC1 scWhnf tp >>= \case -- For variables of type BVVec, create a @Vec n Bool -> a@ function -- as an ExtCns and apply genBVVec to it (asBVVecType -> Just (n, len, a)) -> do @@ -337,6 +340,24 @@ mrProvable bool_tm = tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns +---------------------------------------------------------------------- +-- * Relating Types Heterogeneously with SMT +---------------------------------------------------------------------- + +-- | Return true iff the given types are heterogeneously related +typesHetRelated :: Term -> Term -> MRM Bool +typesHetRelated tp1 tp2 = case matchHet tp1 tp2 of + Just (HetBVNum _) -> return True + Just (HetNumBV _) -> return True + Just (HetBVVecVec (n, len, a) (m, a')) -> mrBvToNat n len >>= \m' -> + (&&) <$> mrProveEq m m' <*> typesHetRelated a a' + Just (HetVecBVVec (m, a') (n, len, a)) -> mrBvToNat n len >>= \m' -> + (&&) <$> mrProveEq m m' <*> typesHetRelated a a' + Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> + (&&) <$> typesHetRelated tpL1 tpL2 <*> typesHetRelated tpR1 tpR2 + Nothing -> mrConvertible tp1 tp2 + + ---------------------------------------------------------------------- -- * Checking Equality with SMT ---------------------------------------------------------------------- @@ -580,7 +601,7 @@ mrProveRelH' _ True tp1 tp2 t1 t2 | Just mh <- matchHet tp1 tp2 = case mh of -- genBVVecFromVec and recurse HetBVVecVec (n, len, _) (m, tpA2) -> do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m + ms_are_eq <- mrProveEq m' m if ms_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] @@ -590,7 +611,7 @@ mrProveRelH' _ True tp1 tp2 t1 t2 | Just mh <- matchHet tp1 tp2 = case mh of mrProveRelH True tp1 tp2' t1 t2' HetVecBVVec (m, tpA1) (n, len, _) -> do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m' m + ms_are_eq <- mrProveEq m' m if ms_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] @@ -627,8 +648,10 @@ mrProveRelH' _ False tp1 tp2 t1 t2 -- As a fallback, for types we can't handle, just check convertibility mrProveRelH' _ het tp1 tp2 t1 t2 = do success <- mrConvertible t1 t2 + tps_eq <- mrConvertible tp1 tp2 if success then return () else - if het then mrDebugPPPrefixSep 2 "mrProveRelH' could not match types: " tp1 "and" tp2 >> - mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 - else mrDebugPPPrefixSep 2 "mrProveEq could not prove convertible: " t1 "and" t2 + if het || not tps_eq + then mrDebugPPPrefixSep 2 "mrProveRelH' could not match types: " tp1 "and" tp2 >> + mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 + else mrDebugPPPrefixSep 2 "mrProveEq could not prove convertible: " t1 "and" t2 TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 315df6e2e2..f1ae5801f4 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -662,7 +662,7 @@ generalizeCoIndHypArgs hyp [(specs1, tp1), (specs2, tp2)] = case matchHet tp1 tp -- FIXME: Could we handle the a /= a' case here and in mrRefinesFunH? Just (HetBVVecVec (n, len, a) (m, a')) -> do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m m' + ms_are_eq <- mrProveEq m m' as_are_eq <- mrConvertible a a' if ms_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) @@ -672,7 +672,7 @@ generalizeCoIndHypArgs hyp [(specs1, tp1), (specs2, tp2)] = case matchHet tp1 tp return $ coIndHypSetArgs hyp'' specs2 bvv_tm Just (HetVecBVVec (m, a') (n, len, a)) -> do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m m' + ms_are_eq <- mrProveEq m m' as_are_eq <- mrConvertible a a' if ms_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) @@ -1056,9 +1056,13 @@ mrRefinesFun tp1 f1 tp2 f2 = mrDebugPPPrefixSep 1 "mrRefinesFun" f1 "|=" f2 f1' <- compFunToTerm f1 >>= liftSC1 scWhnf f2' <- compFunToTerm f2 >>= liftSC1 scWhnf - let lnm = maybe "call_ret_val" id (compFunVarName f1) - rnm = maybe "call_ret_val" id (compFunVarName f2) - mrRefinesFunH mrRefines [] [(lnm, tp1)] f1' [(rnm, tp2)] f2' + let nm1 = maybe "call_ret_val" id (compFunVarName f1) + nm2 = maybe "call_ret_val" id (compFunVarName f2) + f1'' <- mrLambdaLift [(nm1, tp1)] f1' $ \[var] -> flip mrApply var + f2'' <- mrLambdaLift [(nm2, tp2)] f2' $ \[var] -> flip mrApply var + tps1 <- mrTypeOf f1'' + tps2 <- mrTypeOf f2'' + mrRefinesFunH mrRefines [] tps1 f1'' tps2 f2'' -- | The main loop of 'mrRefinesFun' and 'askMRSolver': given a continuation, -- two terms of function type, and two equal-length lists representing the @@ -1069,10 +1073,7 @@ mrRefinesFun tp1 f1 tp2 f2 = -- and call the continuation on the resulting terms. The second argument is -- an accumulator of variables to introduce, innermost first. mrRefinesFunH :: (Term -> Term -> MRM a) -> [Term] -> - [(LocalName,Term)] -> Term -> [(LocalName,Term)] -> Term -> - MRM a - -mrRefinesFunH k vars ((nm1, tp1):tps1) t1 ((nm2, tp2):tps2) t2 = case matchHet tp1 tp2 of + Term -> Term -> Term -> Term -> MRM a -- If we need to introduce a bitvector on one side and a Num on the other, -- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that @@ -1084,13 +1085,11 @@ mrRefinesFunH k vars ((nm1, tp1):tps1) t1 ((nm2, tp2):tps2) t2 = case matchHet t withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> do nat_tm <- liftSC2 scBvToNat n var num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - tps2' <- zipWithM (\i tp -> liftTermLike 0 i num_tm >>= \num_tm' -> - substTermLike i (num_tm' : vars') tp >>= - mapM (liftSC1 scWhnf)) - [0..] tps2 t1'' <- mrApplyAll t1' [var] t2'' <- mrApplyAll t2' [num_tm] - mrRefinesFunH k (var : vars') tps1 t1'' tps2' t2'' + tps1' <- mrTypeOf t1'' + tps2' <- mrTypeOf t2'' >>= liftSC1 scWhnf + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' Just (HetNumBV n) -> let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 @@ -1098,65 +1097,59 @@ mrRefinesFunH k vars ((nm1, tp1):tps1) t1 ((nm2, tp2):tps2) t2 = case matchHet t withUVarLift nm (Type tp2) (vars, t1, t2) $ \var (vars', t1', t2') -> do nat_tm <- liftSC2 scBvToNat n var num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - tps1' <- zipWithM (\i tp -> liftTermLike 0 i num_tm >>= \num_tm' -> - substTermLike i (num_tm' : vars') tp >>= - mapM (liftSC1 scWhnf)) - [0..] tps1 t1'' <- mrApplyAll t1' [num_tm] t2'' <- mrApplyAll t2' [var] - mrRefinesFunH k (var : vars') tps1' t1'' tps2 t2'' + tps1' <- mrTypeOf t1'' >>= liftSC1 scWhnf + tps2' <- mrTypeOf t2'' + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' -- If we need to introduce a BVVec on one side and a non-BVVec vector on the -- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that -- variable on the non-BVVec side -- FIXME: Could we handle the a /= a' case here and in generalizeCoIndHypArgs? Just (HetBVVecVec (n, len, a) (m, a')) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m m' + do lenNat <- mrBvToNat n len + ms_are_eq <- mrProveEq m lenNat as_are_eq <- mrConvertible a a' if ms_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 , asLambdaName t2 ] - withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> - do bvv_tm <- mrGenFromBVVec n len a var "mrRefinesFunH (BVVec/Vec)" m - tps2' <- zipWithM (\i tp -> liftTermLike 0 i bvv_tm >>= \bvv_tm' -> - substTermLike i (bvv_tm' : vars') tp >>= - mapM (liftSC1 scWhnf)) - [0..] tps2 + withUVarLift nm (Type tp1) (vars, n, len, a, m, t1, t2) $ \var (vars', n', len', a', m', t1', t2') -> + do bvv_tm <- mrGenFromBVVec n' len' a' var "mrRefinesFunH (BVVec/Vec)" m' t1'' <- mrApplyAll t1' [var] t2'' <- mrApplyAll t2' [bvv_tm] - mrRefinesFunH k (var : vars') tps1 t1'' tps2' t2'' + tps1' <- mrTypeOf t1'' + tps2' <- mrTypeOf t2'' >>= liftSC1 scWhnf + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' Just (HetVecBVVec (m, a') (n, len, a)) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrConvertible m m' + do lenNat <- mrBvToNat n len + ms_are_eq <- mrProveEq m lenNat as_are_eq <- mrConvertible a a' if ms_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 , asLambdaName t2 ] - withUVarLift nm (Type tp2) (vars, t1, t2) $ \var (vars', t1', t2') -> - do bvv_tm <- mrGenFromBVVec n len a var "mrRefinesFunH (BVVec/Vec)" m - tps1' <- zipWithM (\i tp -> liftTermLike 0 i bvv_tm >>= \bvv_tm' -> - substTermLike i (bvv_tm' : vars') tp >>= - mapM (liftSC1 scWhnf)) - [0..] tps1 - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [bvv_tm] - mrRefinesFunH k (var : vars') tps1' t1'' tps2 t2'' + withUVarLift nm (Type tp2) (vars, n, len, a, m, t1, t2) $ \var (vars', n', len', a', m', t1', t2') -> + do bvv_tm <- mrGenFromBVVec n' len' a' var "mrRefinesFunH (BVVec/Vec)" m' + t1'' <- mrApplyAll t1' [bvv_tm] + t2'' <- mrApplyAll t2' [var] + tps1' <- mrTypeOf t1'' >>= liftSC1 scWhnf + tps2' <- mrTypeOf t2'' + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' -- We always curry pair values before introducing them (NOTE: we do this even -- when the have the same types to ensure we never have to unify a projection -- of an evar with a non-projected value, i.e. evar.1 == val ) Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> - do let tps1' = (nm1 <> "_1", tpL1):(nm1 <> "_2", tpR1):tps1 - tps2' = (nm2 <> "_1", tpL2):(nm2 <> "_2", tpR2):tps2 - t1'' <- mrLambdaLift [(nm1, tpL1), (nm1, tpR1)] t1 $ \[prj1, prj2] t1' -> + do t1'' <- mrLambdaLift [(nm1, tpL1), (nm1, tpR1)] t1 $ \[prj1, prj2] t1' -> liftSC2 scPairValue prj1 prj2 >>= mrApply t1' t2'' <- mrLambdaLift [(nm2, tpL2), (nm2, tpR2)] t2 $ \[prj1, prj2] t2' -> liftSC2 scPairValue prj1 prj2 >>= mrApply t2' + tps1' <- mrTypeOf t1'' + tps2' <- mrTypeOf t2'' mrRefinesFunH k vars tps1' t1'' tps2' t2'' -- Introduce variables of the same type together @@ -1170,18 +1163,26 @@ mrRefinesFunH k vars ((nm1, tp1):tps1) t1 ((nm2, tp2):tps2) t2 = case matchHet t withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> do t1'' <- mrApplyAll t1' [var] t2'' <- mrApplyAll t2' [var] - mrRefinesFunH k (var : vars') tps1 t1'' tps2 t2'' + tps1' <- mrTypeOf t1'' + tps2' <- mrTypeOf t2'' + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' -- Error if we don't have the same number of arguments on both sides -- FIXME: Add a specific error for this case -mrRefinesFunH _ _ ((_,tp1):_) _ [] _ = +mrRefinesFunH _ _ (asPi -> Just (_,tp1,_)) _ (asPi -> Nothing) _ = liftSC0 scUnitType >>= \utp -> throwMRFailure (TypesNotEq (Type tp1) (Type utp)) -mrRefinesFunH _ _ [] _ ((_,tp2):_) _ = +mrRefinesFunH _ _ (asPi -> Nothing) _ (asPi -> Just (_,tp2,_)) _ = liftSC0 scUnitType >>= \utp -> throwMRFailure (TypesNotEq (Type utp) (Type tp2)) -mrRefinesFunH k _ [] t1 [] t2 = k t1 t2 +-- Error if either side's return type is not CompM +mrRefinesFunH _ _ tp1@(asCompM -> Nothing) _ _ _ = + throwMRFailure (NotCompFunType tp1) +mrRefinesFunH _ _ _ _ tp2@(asCompM -> Nothing) _ = + throwMRFailure (NotCompFunType tp2) + +mrRefinesFunH k _ _ t1 _ t2 = k t1 t2 ---------------------------------------------------------------------- @@ -1232,13 +1233,7 @@ askMRSolver sc env timeout t1 t2 = tp2 <- scTypeOf sc t2 >>= scWhnf sc runMRM sc timeout env $ mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 >> - case (asPiList tp1, asPiList tp2) of - ((tps1, asCompM -> Just _), (tps2, asCompM -> Just _)) -> - mrRefinesFunH (askMRSolverH mrRefines) [] tps1 t1 tps2 t2 - ((_, asCompM -> Just _), (_, tp2')) -> - throwMRFailure (NotCompFunType tp2') - ((_, tp1'), _) -> - throwMRFailure (NotCompFunType tp1') + mrRefinesFunH (askMRSolverH mrRefines) [] tp1 t1 tp2 t2 -- | Return the 'FunAssump' to add to the 'MREnv' that would be generated if -- 'askMRSolver' succeeded on the given terms. @@ -1251,10 +1246,4 @@ assumeMRSolver sc env timeout t1 t2 = do tp1 <- scTypeOf sc t1 >>= scWhnf sc tp2 <- scTypeOf sc t2 >>= scWhnf sc runMRM sc timeout env $ - case (asPiList tp1, asPiList tp2) of - ((tps1, asCompM -> Just _), (tps2, asCompM -> Just _)) -> - mrRefinesFunH (askMRSolverH (\_ _ -> return ())) [] tps1 t1 tps2 t2 - ((_, asCompM -> Just _), (_, tp2')) -> - throwMRFailure (NotCompFunType tp2') - ((_, tp1'), _) -> - throwMRFailure (NotCompFunType tp1') + mrRefinesFunH (askMRSolverH (\_ _ -> return ())) [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index b29373b064..687311e4f1 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -36,6 +36,7 @@ import qualified Data.IntMap as IntMap import GHC.Generics import Prettyprinter +import Data.Text (Text, unpack) import Data.Map (Map) import qualified Data.Map as Map @@ -209,19 +210,24 @@ asIsFinite (asApp -> Just (isGlobalDef "CryptolM.isFinite" -> Just (), n)) = Just n asIsFinite _ = Nothing --- | Test if a 'Term' is a 'BVVec' type +-- | Test if a 'Term' is a 'BVVec' type, excluding bitvectors asBVVecType :: Recognizer Term (Term, Term, Term) asBVVecType (asApplyAll -> (isGlobalDef "Prelude.Vec" -> Just _, [(asApplyAll -> - (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) = - Just (n, len, a) + (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) + | Just _ <- asBoolType a = Nothing + | otherwise = Just (n, len, a) asBVVecType _ = Nothing --- | Like 'asVectorType', but returns 'Nothing' if 'asBVVecType' returns 'Just' +-- | Like 'asVectorType', but returns 'Nothing' if 'asBVVecType' returns +-- 'Just' or if the given 'Term' is a bitvector type asNonBVVecVectorType :: Recognizer Term (Term, Term) asNonBVVecVectorType (asBVVecType -> Just _) = Nothing -asNonBVVecVectorType t = asVectorType t +asNonBVVecVectorType (asVectorType -> Just (n, a)) + | Just _ <- asBoolType a = Nothing + | otherwise = Just (n, a) +asNonBVVecVectorType _ = Nothing -- | Like 'asLambda', but only return's the lambda-bound variable's 'LocalName' asLambdaName :: Recognizer Term LocalName @@ -454,6 +460,9 @@ instance PrettyInCtx a => PrettyInCtx [a] where instance {-# OVERLAPPING #-} PrettyInCtx String where prettyInCtx str = return $ fromString str +instance PrettyInCtx Text where + prettyInCtx str = return $ fromString $ unpack str + instance PrettyInCtx Int where prettyInCtx i = return $ viaShow i From 3c773742a395ef5e9571517d4099ba2875d4277f Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 16:16:49 -0700 Subject: [PATCH 04/12] introduce equality arguments as assumptions --- heapster-saw/examples/sha512.saw | 2 +- src/SAWScript/Prover/MRSolver/Solver.hs | 34 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/heapster-saw/examples/sha512.saw b/heapster-saw/examples/sha512.saw index 6624a9f6fc..cc67a6479e 100644 --- a/heapster-saw/examples/sha512.saw +++ b/heapster-saw/examples/sha512.saw @@ -70,7 +70,7 @@ heapster_typecheck_fun env "processBlock" heapster_set_translation_checks env false; heapster_typecheck_fun env "processBlocks" "(num:bv 64). arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ - \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ + \ arg1:(num )), \ \ arg2:eq(llvmword(num)) -o \ \ arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index f1ae5801f4..a1a7f3c252 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -1074,6 +1074,40 @@ mrRefinesFun tp1 f1 tp2 f2 = -- an accumulator of variables to introduce, innermost first. mrRefinesFunH :: (Term -> Term -> MRM a) -> [Term] -> Term -> Term -> Term -> Term -> MRM a + +-- Introduce equalities on either side as assumptions +mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t1 tps2 t2 = + mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> + mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> + mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> + liftSC2 scBoolEq b1 b2 >>= \eq -> + withAssumption eq $ + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm1] ++ catMaybes [ asLambdaName t1 ] in + withUVarLift nm (Type tp1) (vars, t1, tps2, t2) $ \var (vars', t1', tps2', t2') -> + do t1'' <- mrApplyAll t1' [var] + tps1' <- mrTypeOf t1'' + mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2' +mrRefinesFunH k vars tps1 t1 tps2@(asPi -> Just (nm2, tp2@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t2 = + mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> + mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> + mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> + liftSC2 scBoolEq b1 b2 >>= \eq -> + withAssumption eq $ + let nm = maybe "_" id $ find ((/=) '_' . Text.head) + $ [nm2] ++ catMaybes [ asLambdaName t2 ] in + withUVarLift nm (Type tp2) (vars, tps1, t1, t2) $ \var (vars', tps1', t1', t2') -> + do t2'' <- mrApplyAll t2' [var] + tps2' <- mrTypeOf t2'' + mrRefinesFunH k (var : vars') tps1' t1' tps2' t2'' + +mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1, _)) t1 + tps2@(asPi -> Just (nm2, tp2, _)) t2 = + mrUVarCtx >>= \uvarCtx -> + debugPretty 3 ("mrRefinesFunH uvars:" <> ppCtx uvarCtx) >> + mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> + mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> + case matchHet tp1 tp2 of -- If we need to introduce a bitvector on one side and a Num on the other, -- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that From c18c121267f6072302f51cc804719e4176d32091 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 16:16:58 -0700 Subject: [PATCH 05/12] WIP on processBlocks --- heapster-saw/examples/sha512_mr_solver.saw | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index bd7ea87192..911d92f0a2 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -27,7 +27,8 @@ monadify_term {{ processBlock_spec }}; monadify_term {{ processBlocks_loop_spec }}; monadify_term {{ processBlocks_spec }}; -mr_solver_prove round_00_15 {{ round_00_15_spec }}; -mr_solver_prove round_16_80 {{ round_16_80_spec }}; -mr_solver_prove processBlock {{ processBlock_spec }}; -// mr_solver_prove processBlocks {{ processBlocks_spec }}; +mr_solver_set_debug_level 3; +mr_solver_assume round_00_15 {{ round_00_15_spec }}; +mr_solver_assume round_16_80 {{ round_16_80_spec }}; +mr_solver_assume processBlock {{ processBlock_spec }}; +mr_solver_prove processBlocks {{ processBlocks_spec }}; From 59dfb09096cc435122e199f87961551047290e7c Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 15 Aug 2022 17:03:34 -0700 Subject: [PATCH 06/12] revert lifting of uvar ctx in withUVars --- src/SAWScript/Prover/MRSolver/Monad.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 9db935e10d..ac62470757 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -679,7 +679,7 @@ withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a withUVars [] f = f [] withUVars ctx f = do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars - ctx_u <- zip nms <$> mapM (liftTermLike 0 (length ctx) . Type . snd) ctx + let ctx_u = zip nms $ map (Type . snd) ctx assumps' <- mrAssumptions >>= liftTerm 0 (length ctx) dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (length ctx)) vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] From 7de0965875bab7e7f33792991ea6f771adbf8e15 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Tue, 16 Aug 2022 12:55:18 -0700 Subject: [PATCH 07/12] add MRVarCtx to avoid ambiguity about var ctx orderings --- src/SAWScript/Prover/MRSolver/Monad.hs | 113 +++++++++++++----------- src/SAWScript/Prover/MRSolver/SMT.hs | 9 +- src/SAWScript/Prover/MRSolver/Solver.hs | 11 ++- src/SAWScript/Prover/MRSolver/Term.hs | 103 ++++++++++++++++----- 4 files changed, 149 insertions(+), 87 deletions(-) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index ac62470757..9c751cc30d 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -177,7 +177,7 @@ instance PrettyInCtx MRFailure where -- | Render a 'MRFailure' to a 'String' showMRFailure :: MRFailure -> String -showMRFailure = showInCtx [] +showMRFailure = showInCtx emptyMRVarCtx -- | Render a 'MRFailure' to a 'String' without its context (see -- 'mrFailureWithoutCtx') @@ -220,10 +220,8 @@ asEVarApp _ _ = Nothing -- for some universal context @x1:T1, ..., xn:Tn@ and some lists of argument -- expressions @y1, ..., ym@ and @z1, ..., zl@ over the universal context. data CoIndHyp = CoIndHyp { - -- | The uvars that were in scope when this assmption was created, in order - -- from outermost to innermost; that is, the uvars as "seen from outside their - -- scope", which is the reverse of the order of 'mrUVars', below - coIndHypCtx :: [(LocalName,Term)], + -- | The uvars that were in scope when this assmption was created + coIndHypCtx :: MRVarCtx, -- | The LHS function name coIndHypLHSFun :: FunName, -- | The RHS function name @@ -263,10 +261,11 @@ coIndHypSetArgs hyp specs x = -- | Add a variable to the context of a coinductive hypothesis, returning the -- updated coinductive hypothesis and a 'Term' which is the new variable coIndHypWithVar :: CoIndHyp -> LocalName -> Type -> MRM (CoIndHyp, Term) -coIndHypWithVar (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) nm (Type tp) = +coIndHypWithVar (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) nm tp = do var <- liftSC1 scLocalVar 0 + let ctx' = mrVarCtxAppend (singletonMRVarCtx nm tp) ctx (args1', args2') <- liftTermLike 0 1 (args1, args2) - return (CoIndHyp (ctx ++ [(nm,tp)]) f1 f2 args1' args2' invar1 invar2, var) + return (CoIndHyp ctx' f1 f2 args1' args2' invar1 invar2, var) -- | A map from pairs of function names to co-inductive hypotheses over those -- names @@ -274,8 +273,9 @@ type CoIndHyps = Map (FunName, FunName) CoIndHyp instance PrettyInCtx CoIndHyp where prettyInCtx (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) = - local (const $ map fst $ reverse ctx) $ - prettyAppList [return (ppCtx ctx <> "."), + -- ignore whatever context we're in and use `ctx` instead + return $ flip runPPInCtxM ctx $ + prettyAppList [prettyInCtx ctx, return ".", (case invar1 of Just f -> prettyTermApp f args1 Nothing -> return "True"), return "=>", @@ -307,10 +307,8 @@ data MRInfo = MRInfo { mriSC :: SharedContext, -- | SMT timeout for SMT calls made by Mr. Solver mriSMTTimeout :: Maybe Integer, - -- | The current context of universal variables, which are free SAW core - -- variables, in order from innermost to outermost, i.e., where element @0@ - -- corresponds to deBruijn index @0@ - mriUVars :: [(LocalName,Type)], + -- | The current context of universal variables + mriUVars :: MRVarCtx, -- | The top-level Mr Solver environment mriEnv :: MREnv, -- | The current set of co-inductive hypotheses @@ -358,7 +356,7 @@ mrSMTTimeout :: MRM (Maybe Integer) mrSMTTimeout = mriSMTTimeout <$> ask -- | Get the current value of 'mriUVars' -mrUVars :: MRM [(LocalName,Type)] +mrUVars :: MRM MRVarCtx mrUVars = mriUVars <$> ask -- | Get the current function assumptions @@ -396,7 +394,8 @@ runMRM sc timeout env m = do true_tm <- scBool sc True let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, mriEnv = env, - mriUVars = [], mriCoIndHyps = Map.empty, + mriUVars = emptyMRVarCtx, + mriCoIndHyps = Map.empty, mriAssumptions = true_tm, mriDataTypeAssumps = HashMap.empty } let init_st = MRState { mrsVars = Map.empty } @@ -595,22 +594,21 @@ mrBvToNat n len = liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in -- the order as seen "from the outside" -mrUVarCtx :: MRM [(LocalName,Term)] -mrUVarCtx = reverse <$> mrUVarCtxRev +mrUVarsOuterToInner :: MRM [(LocalName,Term)] +mrUVarsOuterToInner = mrVarCtxOuterToInner <$> mrUVars -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the most recently bound uvar first, i.e., in -- the order as seen "from the inside" -mrUVarCtxRev :: MRM [(LocalName,Term)] -mrUVarCtxRev = map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars +mrUVarsInnerToOuter :: MRM [(LocalName,Term)] +mrUVarsInnerToOuter = mrVarCtxInnerToOuter <$> mrUVars -- | Get the type of a 'Term' in the current uvar context mrTypeOf :: Term -> MRM Term mrTypeOf t = - -- NOTE: scTypeOf' wants the type context in the most recently bound var - -- first, i.e., in the mrUVarCtxRev order + -- NOTE: scTypeOf' wants the type context in the most recently bound var first -- mrDebugPPPrefix 3 "mrTypeOf:" t >> - mrUVarCtxRev >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t + mrUVarsInnerToOuter >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t -- | Check if two 'Term's are convertible in the 'MRM' monad mrConvertible :: Term -> Term -> MRM Bool @@ -652,7 +650,9 @@ mrLambdaLift :: TermLike tm => [(LocalName,Term)] -> tm -> ([Term] -> tm -> MRM Term) -> MRM Term mrLambdaLift [] t f = f [] t mrLambdaLift ctx t f = - do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars + do -- uniquifyNames doesn't care about the order of the names in its second, + -- argument, thus either inner-to-outer or outer-to-inner would work + nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVarsInnerToOuter let ctx' = zipWith (\nm (_,tp) -> (nm,tp)) nms ctx vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] t' <- liftTermLike 0 (length ctx) t @@ -662,7 +662,7 @@ mrLambdaLift ctx t f = -- variable, which is passed as a 'Term' to the sub-computation. Note that any -- assumptions made in the sub-computation will be lost when it completes. withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a -withUVar nm (Type tp) m = withUVars [(nm,tp)] (\[v] -> m v) +withUVar nm tp m = withUVars (singletonMRVarCtx nm tp) (\[v] -> m v) -- | Run a MR Solver computation in a context extended with a universal variable -- and pass it the lifting (in the sense of 'incVars') of an MR Solver term @@ -673,17 +673,25 @@ withUVarLift nm tp t m = -- | Run a MR Solver computation in a context extended with a list of universal -- variables, passing 'Term's for those variables to the supplied computation. --- The variables are bound "outside in", meaning the first variable in the list --- is bound outermost, and so will have the highest deBruijn index. -withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a -withUVars [] f = f [] +withUVars :: MRVarCtx -> ([Term] -> MRM a) -> MRM a +withUVars (mrVarCtxLength -> 0) f = f [] withUVars ctx f = - do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars - let ctx_u = zip nms $ map (Type . snd) ctx - assumps' <- mrAssumptions >>= liftTerm 0 (length ctx) - dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (length ctx)) - vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1] - local (\info -> info { mriUVars = reverse ctx_u ++ mriUVars info, + do -- for uniquifyNames, we want to consider the oldest names first, thus we + -- must pass the first argument in outer-to-inner order. uniquifyNames + -- doesn't care about the order of the names in its second, argument, thus + -- either inner-to-outer or outer-to-inner would work + let ctx_l = mrVarCtxOuterToInner ctx + nms <- uniquifyNames (map fst ctx_l) <$> map fst <$> mrUVarsInnerToOuter + let ctx_u = mrVarCtxFromOuterToInner $ zip nms $ map snd ctx_l + -- lift all the variables in our assumptions by the number of new uvars + -- we're adding (we do not have to lift the types in our uvar context + -- itself, since each type is in the context of all older uvars - see the + -- definition of MRVarCtx) + assumps' <- mrAssumptions >>= liftTerm 0 (mrVarCtxLength ctx) + dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (mrVarCtxLength ctx)) + -- make terms for our new uvars, extend the context, and continue + vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. mrVarCtxLength ctx - 1] + local (\info -> info { mriUVars = mrVarCtxAppend ctx_u (mriUVars info), mriAssumptions = assumps', mriDataTypeAssumps = dataTypeAssumps' }) $ mrDebugPPPrefix 3 "withUVars:" ctx_u >> @@ -693,35 +701,35 @@ withUVars ctx f = withNoUVars :: MRM a -> MRM a withNoUVars m = do true_tm <- liftSC1 scBool True - local (\info -> info { mriUVars = [], mriAssumptions = true_tm, + local (\info -> info { mriUVars = emptyMRVarCtx, mriAssumptions = true_tm, mriDataTypeAssumps = HashMap.empty }) m -- | Run a MR Solver in a context of only the specified UVars, no others - -- note that this also clears all assumptions -withOnlyUVars :: [(LocalName,Term)] -> MRM a -> MRM a +withOnlyUVars :: MRVarCtx -> MRM a -> MRM a withOnlyUVars vars m = withNoUVars $ withUVars vars $ const m -- | Build 'Term's for all the uvars currently in scope, ordered from least to -- most recently bound getAllUVarTerms :: MRM [Term] getAllUVarTerms = - (length <$> mrUVars) >>= \len -> + (mrVarCtxLength <$> mrUVars) >>= \len -> mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] -- | Lambda-abstract all the current uvars out of a 'Term', with the least -- recently bound variable being abstracted first lambdaUVarsM :: Term -> MRM Term -lambdaUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scLambdaList ctx t +lambdaUVarsM t = mrUVarsOuterToInner >>= \ctx -> liftSC2 scLambdaList ctx t -- | Pi-abstract all the current uvars out of a 'Term', with the least recently -- bound variable being abstracted first piUVarsM :: Term -> MRM Term -piUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scPiList ctx t +piUVarsM t = mrUVarsOuterToInner >>= \ctx -> liftSC2 scPiList ctx t -- | Instantiate all uvars in a term using the supplied function instantiateUVarsM :: TermLike a => (LocalName -> Term -> MRM Term) -> a -> MRM a instantiateUVarsM f a = - do ctx <- mrUVarCtx + do ctx <- mrUVarsOuterToInner -- Remember: the uvar context is outermost to innermost, so we bind -- variables from left to right, substituting earlier ones into the types -- of later ones, but all substitutions are in reverse order, since @@ -850,15 +858,15 @@ mrFreshEVar nm (Type tp) = mrSetVarInfo var (EVarInfo Nothing) mrVarTerm var --- | Return a fresh sequence of existential variables for a context of variable --- names and types, assuming each variable is free in the types that occur after --- it in the list. Return the new evars all applied to the current uvars. -mrFreshEVars :: [(LocalName,Term)] -> MRM [Term] -mrFreshEVars = helper [] where +-- | Return a fresh sequence of existential variables from a 'MRVarCtx'. +-- Return the new evars all applied to the current uvars. +mrFreshEVars :: MRVarCtx -> MRM [Term] +mrFreshEVars = helper [] . mrVarCtxOuterToInner where -- Return fresh evars for the suffix of a context of variable names and types, -- where the supplied Terms are evars that have already been generated for the -- earlier part of the context, and so must be substituted into the remaining - -- types in the context + -- types in the context. Since we want to make fresh evars for the oldest + -- variables first, the second argument must be in outer-to-inner order. helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] helper evars [] = return evars helper evars ((nm,tp):ctx) = @@ -1019,7 +1027,7 @@ withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a withFunAssump fname args rhs m = do k <- CompFunReturn <$> Type <$> mrFunOutType fname args mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args k) "|=" rhs - ctx <- mrUVarCtx + ctx <- mrUVars assumps <- mrFunAssumps let assump = FunAssump ctx args (RewriteFunAssump rhs) let assumps' = Map.insert fname assump assumps @@ -1118,13 +1126,11 @@ debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp -- | Pretty-print an object in the current context if the current debug level is -- at least the supplied 'Int' debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () -debugPrettyInCtx i a = - mrUVars >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) +debugPrettyInCtx i a = mrUVars >>= \ctx -> debugPrint i (showInCtx ctx a) -- | Pretty-print an object relative to the current context mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc -mrPPInCtx a = - runReader (prettyInCtx a) <$> map fst <$> mrUVars +mrPPInCtx a = runPPInCtxM (prettyInCtx a) <$> mrUVars -- | Pretty-print the result of 'ppWithPrefix' relative to the current uvar -- context to 'stderr' if the debug level is at least the 'Int' provided @@ -1132,7 +1138,7 @@ mrDebugPPPrefix :: PrettyInCtx a => Int -> String -> a -> MRM () mrDebugPPPrefix i pre a = mrUVars >>= \ctx -> debugPretty i $ - flip runReader (map fst ctx) (group <$> nest 2 <$> ppWithPrefix pre a) + runPPInCtxM (group <$> nest 2 <$> ppWithPrefix pre a) ctx -- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar -- context to 'stderr' if the debug level is at least the 'Int' provided @@ -1141,5 +1147,4 @@ mrDebugPPPrefixSep :: (PrettyInCtx a, PrettyInCtx b) => mrDebugPPPrefixSep i pre a1 sp a2 = mrUVars >>= \ctx -> debugPretty i $ - flip runReader (map fst ctx) (group <$> nest 2 <$> - ppWithPrefixSep pre a1 sp a2) + runPPInCtxM (group <$> nest 2 <$> ppWithPrefixSep pre a1 sp a2) ctx diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 9bd6b2d61d..9cba35cad0 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -260,10 +260,10 @@ mrNormTerm t = -- removing those lambdas mrNormOpenTerm :: Term -> MRM Term mrNormOpenTerm body = - do ctx <- mrUVarCtx - fun_term <- liftSC2 scLambdaList ctx body + do length_ctx <- mrVarCtxLength <$> mrUVars + fun_term <- lambdaUVarsM body normed_fun <- mrNormTerm fun_term - return (peel_lambdas (length ctx) normed_fun) + return (peel_lambdas length_ctx normed_fun) where peel_lambdas :: Int -> Term -> Term peel_lambdas 0 t = t @@ -310,8 +310,7 @@ mrProvableRaw prop_term = mrProvable :: Term -> MRM Bool mrProvable (asBool -> Just b) = return b mrProvable bool_tm = - do uvarCtx <- mrUVarCtx - debugPretty 3 $ "mrProvable uvars:" <> ppCtx uvarCtx + do mrUVars >>= mrDebugPPPrefix 3 "mrProvable uvars:" assumps <- mrAssumptions prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue prop_inst <- mrSubstEVars prop >>= instantiateUVarsM instUVar diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index a1a7f3c252..97c6eb0dab 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -547,7 +547,7 @@ proveCoIndHypInvariant hyp = -- restored and the computation is re-run with the widened hypothesis. mrRefinesCoInd :: FunName -> [Term] -> FunName -> [Term] -> MRM () mrRefinesCoInd f1 args1 f2 args2 = - do ctx <- mrUVarCtx + do ctx <- mrUVars preF1 <- mrGetInvariant f1 preF2 <- mrGetInvariant f2 let hyp = CoIndHyp ctx f1 f2 args1 args2 preF1 preF2 @@ -689,7 +689,7 @@ generalizeCoIndHypArgs hyp [(specs1, tp1), (specs2, tp2)] = case matchHet tp1 tp Nothing -> throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) -generalizeCoIndHypArgs _ specs = map fst <$> mrUVars >>= \uvar_ctx -> +generalizeCoIndHypArgs _ specs = mrUVars >>= \uvar_ctx -> -- Being in this case implies we have types @tp1, tp2, tp3@ which are related -- by 'typesHetRelated' but no two of them are convertible. As of the time of -- writing, the only way this could be possible is if the types are pair @@ -1103,8 +1103,7 @@ mrRefinesFunH k vars tps1 t1 tps2@(asPi -> Just (nm2, tp2@(asEq -> Just (asBoolT mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1, _)) t1 tps2@(asPi -> Just (nm2, tp2, _)) t2 = - mrUVarCtx >>= \uvarCtx -> - debugPretty 3 ("mrRefinesFunH uvars:" <> ppCtx uvarCtx) >> + mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> case matchHet tp1 tp2 of @@ -1241,14 +1240,14 @@ askMRSolverH f t1 t2 = -- If t1 and t2 are both named functions, our result is the opaque -- FunAssump that forall xs. f1 xs |= f2 xs' (FunBind f1 args1 (CompFunReturn _), FunBind f2 args2 (CompFunReturn _)) -> - mrUVarCtx >>= \uvar_ctx -> + mrUVars >>= \uvar_ctx -> return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, fassumpArgs = args1, fassumpRHS = OpaqueFunAssump f2 args2 }) -- If just t1 is a named function, our result is the rewrite FunAssump -- that forall xs. f1 xs |= m2 (FunBind f1 args1 (CompFunReturn _), _) -> - mrUVarCtx >>= \uvar_ctx -> + mrUVars >>= \uvar_ctx -> return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, fassumpArgs = args1, fassumpRHS = RewriteFunAssump m2 }) diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 687311e4f1..d7c2c6ffe6 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} @@ -111,6 +113,58 @@ funNameTerm (GlobalName gdef (TermProjRecord fname:projs)) = -- | A term specifically known to be of type @sort i@ for some @i@ newtype Type = Type Term deriving (Generic, Show) +-- | A context of variables, with names and types. To avoid confusion as to +-- how variables are ordered, do not use this type's constructor directly. +-- Instead, use the combinators defined below. +newtype MRVarCtx = MRVarCtx [(LocalName,Type)] deriving (Generic, Show) + +-- | Build an empty context of variables +emptyMRVarCtx :: MRVarCtx +emptyMRVarCtx = MRVarCtx [] + +-- | Build a context with a single variable of the given name and type +singletonMRVarCtx :: LocalName -> Type -> MRVarCtx +singletonMRVarCtx nm tp = MRVarCtx [(nm,tp)] + +-- | Add a context of new variables (the first argument) to an existing context +-- (the second argument). The new variables to add must be in the existing +-- context, i.e. all the types in the first argument must be in the context of +-- the second argument. +mrVarCtxAppend :: MRVarCtx -> MRVarCtx -> MRVarCtx +mrVarCtxAppend (MRVarCtx ctx1) (MRVarCtx ctx2) = MRVarCtx (ctx1 ++ ctx2) + +-- | Return the number of variables in the given context +mrVarCtxLength :: MRVarCtx -> Int +mrVarCtxLength (MRVarCtx ctx) = length ctx + +-- | Return a list of the names and types of the variables in the given +-- context in order from innermost to outermost, i.e., where element @i@ +-- corresponds to deBruijn index @i@, and each type is in the context of +-- all the variables which come after it in the list (i.e. all the variables +-- which come after a type in the list are free in that type). In other words, +-- the list is ordered from newest to oldest variable. +mrVarCtxInnerToOuter :: MRVarCtx -> [(LocalName,Term)] +mrVarCtxInnerToOuter (MRVarCtx ctx) = map (\(nm, Type tp) -> (nm, tp)) ctx + +-- | Build a context of variables from a list of names and types in innermost +-- to outermost order - see 'mrVarCtxInnerToOuter'. +mrVarCtxFromInnerToOuter :: [(LocalName,Term)] -> MRVarCtx +mrVarCtxFromInnerToOuter = MRVarCtx . map (\(nm,tp) -> (nm, Type tp)) + +-- | Return a list of the names and types of the variables in the given +-- context in order from outermost to innermost, i.e., where element @i@ +-- corresponds to deBruijn index @len - i@, and each type is in the context of +-- all the variables which come before it in the list (i.e. all the variables +-- which come before a type in the list are free in that type). In other words, +-- the list is ordered from oldest to newest variable. +mrVarCtxOuterToInner :: MRVarCtx -> [(LocalName,Term)] +mrVarCtxOuterToInner = reverse . mrVarCtxInnerToOuter + +-- | Build a context of variables from a list of names and types in outermost +-- to innermost order - see 'mrVarCtxOuterToInner'. +mrVarCtxFromOuterToInner :: [(LocalName,Term)] -> MRVarCtx +mrVarCtxFromOuterToInner = mrVarCtxFromInnerToOuter . reverse + -- | A Haskell representation of a @CompM@ in "monadic normal form" data NormComp = ReturnM Term -- ^ A term @returnM a x@ @@ -253,10 +307,8 @@ data FunAssumpRHS = OpaqueFunAssump FunName [Term] -- expressions @ei@ over the universal @xj@ variables, and some right-hand side -- computation expression @m@. data FunAssump = FunAssump { - -- | The uvars that were in scope when this assmption was created, in order - -- from outermost to innermost; that is, the uvars as "seen from outside their - -- scope", which is the reverse of the order of 'mrUVars', below - fassumpCtx :: [(LocalName,Term)], + -- | The uvars that were in scope when this assumption was created + fassumpCtx :: MRVarCtx, -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars fassumpArgs :: [Term], -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars @@ -394,7 +446,8 @@ instance TermLike LocalName where liftTermLike _ _ = return substTermLike _ _ = return -deriving instance TermLike Type +deriving anyclass instance TermLike Type +deriving anyclass instance TermLike MRVarCtx deriving instance TermLike NormComp deriving instance TermLike CompFun deriving instance TermLike Comp @@ -404,17 +457,24 @@ deriving instance TermLike Comp -- * Pretty-Printing MR Solver Terms ---------------------------------------------------------------------- --- | The monad for pretty-printing in a context of SAW core variables -type PPInCtxM = Reader [LocalName] +-- | The monad for pretty-printing in a context of SAW core variables. The +-- context is in innermost-to-outermost order, i.e. from newest to oldest +-- variable (see 'mrVarCtxInnerToOuter' for more detail on this ordering). +newtype PPInCtxM a = PPInCtxM (Reader [LocalName] a) + deriving newtype (Functor, Applicative, Monad, + MonadReader [LocalName]) + +-- | Run a 'PPInCtxM' computation in the given 'MRVarCtx' context +runPPInCtxM :: PPInCtxM a -> MRVarCtx -> a +runPPInCtxM (PPInCtxM m) = runReader m . map fst . mrVarCtxInnerToOuter -- | Pretty-print an object in a SAW core context and render to a 'String' -showInCtx :: PrettyInCtx a => [LocalName] -> a -> String -showInCtx ctx a = - renderSawDoc defaultPPOpts $ runReader (prettyInCtx a) ctx +showInCtx :: PrettyInCtx a => MRVarCtx -> a -> String +showInCtx ctx a = renderSawDoc defaultPPOpts $ runPPInCtxM (prettyInCtx a) ctx -- | Pretty-print an object in the empty SAW core context ppInEmptyCtx :: PrettyInCtx a => a -> SawDoc -ppInEmptyCtx a = runReader (prettyInCtx a) [] +ppInEmptyCtx a = runPPInCtxM (prettyInCtx a) emptyMRVarCtx -- | A generic function for pretty-printing an object in a SAW core context of -- locally-bound names @@ -433,17 +493,16 @@ prettyTermApp :: Term -> [Term] -> PPInCtxM SawDoc prettyTermApp f_top args = prettyInCtx $ foldl (\f arg -> Unshared $ App f arg) f_top args --- | FIXME: move this helper function somewhere better... -ppCtx :: [(LocalName,Term)] -> SawDoc -ppCtx = align . sep . helper [] where - helper :: [LocalName] -> [(LocalName,Term)] -> [SawDoc] - helper _ [] = [] - helper ns [(n,tp)] = - [ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> - ppTermInCtx defaultPPOpts ns tp] - helper ns ((n,tp):ctx) = - (ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> - ppTermInCtx defaultPPOpts ns tp <> ",") : (helper (n:ns) ctx) +instance PrettyInCtx MRVarCtx where + prettyInCtx = return . align . sep . helper [] . mrVarCtxOuterToInner where + helper :: [LocalName] -> [(LocalName,Term)] -> [SawDoc] + helper _ [] = [] + helper ns [(n, tp)] = + [ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> + ppTermInCtx defaultPPOpts ns tp] + helper ns ((n, tp):ctx) = + (ppTermInCtx defaultPPOpts (n:ns) (Unshared $ LocalVar 0) <> ":" <> + ppTermInCtx defaultPPOpts ns tp <> ",") : (helper (n:ns) ctx) instance PrettyInCtx SawDoc where prettyInCtx pp = return pp From 997b05098b550e0b3e836e31458dc0520e6c82ec Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Wed, 17 Aug 2022 11:57:44 -0700 Subject: [PATCH 08/12] address comments about MRVarCtx --- src/SAWScript/Prover/MRSolver/Term.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index d7c2c6ffe6..4363ba5ca0 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -116,7 +116,11 @@ newtype Type = Type Term deriving (Generic, Show) -- | A context of variables, with names and types. To avoid confusion as to -- how variables are ordered, do not use this type's constructor directly. -- Instead, use the combinators defined below. -newtype MRVarCtx = MRVarCtx [(LocalName,Type)] deriving (Generic, Show) +newtype MRVarCtx = MRVarCtx [(LocalName,Type)] + -- ^ Internally, we store these names and types in order + -- from innermost to outermost variable, see + -- 'mrVarCtxInnerToOuter' + deriving (Generic, Show) -- | Build an empty context of variables emptyMRVarCtx :: MRVarCtx @@ -447,7 +451,6 @@ instance TermLike LocalName where substTermLike _ _ = return deriving anyclass instance TermLike Type -deriving anyclass instance TermLike MRVarCtx deriving instance TermLike NormComp deriving instance TermLike CompFun deriving instance TermLike Comp From 7d52a80d938fb75330b033187955556541f22193 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Fri, 26 Aug 2022 13:53:08 -0700 Subject: [PATCH 09/12] add genCryM/atCryM and tweak Prims to get split normalizing --- cryptol-saw-core/saw/CryptolM.sawcore | 26 ++- saw-core/src/Verifier/SAW/Simulator/Prims.hs | 6 +- src/SAWScript/Interpreter.hs | 3 +- src/SAWScript/Prover/MRSolver/SMT.hs | 191 +++++++++++++------ src/SAWScript/Prover/MRSolver/Solver.hs | 20 +- 5 files changed, 176 insertions(+), 70 deletions(-) diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index e85580616a..2bc36ca777 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -331,6 +331,12 @@ PLiteralSeqBoolM = -- Sequences +-- Alternate versions of gen and at to get around the behavior of the default prims +genCryM : (n : Nat) -> (a : sort 0) -> (Nat -> a) -> Vec n a; +genCryM = gen; +atCryM : (n : Nat) -> (a : isort 0) -> Vec n a -> Nat -> a; +atCryM = at; + -- FIXME: a number of the non-monadic versions of these functions contain calls -- to finNumRec, which calls error on non-finite numbers. The monadic versions -- of these, below, should be reimplemented to not contain finNumRec, but to @@ -404,6 +410,15 @@ ecTakeM = ecTake TCInf n (CompM a) s); -} +-- An alternate version of join from Prelude to get around the default Prim +joinCryM : (m n : Nat) + -> (a : isort 0) + -> Vec m (Vec n a) + -> Vec (mulNat m n) a; +joinCryM m n a v = + genCryM (mulNat m n) a (\ (i : Nat) -> + atCryM n a (at m (Vec n a) v (divNat i n)) (modNat i n)); + -- FIXME primitive ecDropM : (m : Num) -> isFinite m -> (n : Num) -> (a : sort 0) -> @@ -419,7 +434,7 @@ ecJoinM = (\ (n:Num) -> (a:isort 0) -> Vec m (mseq n a) -> mseq (tcMul (TCNum m) n) a) -- Case for (TCNum m, TCNum n) - (\ (n:Nat) -> \ (a:isort 0) -> join m n a)) + (\ (n:Nat) -> \ (a:isort 0) -> joinCryM m n a)) -- No case for (TCNum m, TCInf), shoudn't happen (finNumRec (\ (n:Num) -> (a:isort 0) -> Stream (CompM (mseq n a)) -> @@ -438,6 +453,13 @@ ecJoinM = n)); -- No case for (TCInf, TCInf), shouldn't happen +-- An alternate version of split from Prelude to get around the default Prim +splitCryM : (m n : Nat) -> (a : isort 0) -> Vec (mulNat m n) a -> Vec m (Vec n a); +splitCryM m n a v = + genCryM m (Vec n a) (\ (i : Nat) -> + genCryM n a (\ (j : Nat) -> + atCryM (mulNat m n) a v (addNat (mulNat i n) j))); + ecSplitM : (m n : Num) -> (a : sort 0) -> mseq (tcMul m n) a -> mseq m (mseq n a); ecSplitM = @@ -449,7 +471,7 @@ ecSplitM = (\ (n:Num) -> (a:isort 0) -> mseq (tcMul (TCNum m) n) a -> Vec m (mseq n a)) -- Case for (TCNum m, TCNum n) - (\ (n:Nat) -> \ (a:isort 0) -> split m n a)) + (\ (n:Nat) -> \ (a:isort 0) -> splitCryM m n a)) -- No case for (TCNum m, TCInf), shouldn't happen (finNumRec (\ (n:Num) -> (a:isort 0) -> mseq (tcMul TCInf n) a -> diff --git a/saw-core/src/Verifier/SAW/Simulator/Prims.hs b/saw-core/src/Verifier/SAW/Simulator/Prims.hs index e6b07aaf36..444685f616 100644 --- a/saw-core/src/Verifier/SAW/Simulator/Prims.hs +++ b/saw-core/src/Verifier/SAW/Simulator/Prims.hs @@ -63,6 +63,7 @@ import Control.Monad.Fix (MonadFix(mfix)) import Control.Monad.Trans import Control.Monad.Trans.Maybe import Control.Monad.Trans.Except +import Data.Functor import Data.Maybe (fromMaybe) import Data.Bitraversable import Data.Bits @@ -503,7 +504,8 @@ selectV mux maxValue valueFn v = impl len 0 bvNatOp :: (VMonad l, Show (Extra l)) => BasePrims l -> Prim l bvNatOp bp = natFun $ \w -> - strictFun $ \v -> + -- make sure our nat has a size, i.e. that 'natToWord' will not fail + natSizeFun $ either snd VNat <&> \v -> Prim (VWord <$> natToWord bp (fromIntegral w) v) -- FIXME check for overflow on w -- bvToNat : (n : Nat) -> Vec n Bool -> Nat; @@ -568,7 +570,7 @@ natToWord bp w val = VBVToNat xsize v -> do x <- toWord (bpPack bp) v case compare xsize (fromIntegral w) of - GT -> panic $ "natToWord: not enough bits for: " ++ show val + GT -> bpBvSlice bp (xsize - fromIntegral w) (fromIntegral w) x EQ -> return x LT -> -- zero-extend x to width w do pad <- bpBvLit bp (fromIntegral w - xsize) 0 diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index d10529e645..d78bdd0dfd 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3494,7 +3494,8 @@ primitives = Map.fromList (pureVal mrSolverSetDebug) Experimental [ "Set the debug level for Mr. Solver; 0 = no debug output," - , " 1 = some debug output, 2 = all debug output" ] + , " 1 = basic debug output, 2 = verbose debug output," + , " 3 = all debug output" ] --------------------------------------------------------------------- diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 9cba35cad0..0ab447f2a3 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -2,6 +2,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE TupleSections #-} {- | Module : SAWScript.Prover.MRSolver.SMT @@ -21,6 +22,7 @@ import qualified Data.Vector as V import Numeric.Natural (Natural) import Control.Monad.Except import qualified Control.Exception as X +import Control.Monad.Trans.Maybe import Data.Map (Map) import qualified Data.Map as Map @@ -37,6 +39,9 @@ import qualified Verifier.SAW.Prim as Prim import Verifier.SAW.Simulator.Value import Verifier.SAW.Simulator.TermModel import Verifier.SAW.Simulator.Prims +import Verifier.SAW.Module +import Verifier.SAW.Prelude.Constants +import Verifier.SAW.FiniteValue import SAWScript.Proof (termToProp, propToTerm, prettyProp) import What4.Solver @@ -70,12 +75,18 @@ genBVVecTerm sc n_tm len_tm a_tm f_tm = asGenBVVecTerm :: Recognizer Term (Term, Term, Term, Term) asGenBVVecTerm (asApplyAll -> (isGlobalDef "Prelude.genBVVec" -> Just _, - [n, len, a, - (asLambdaList -> ([_,_], e))])) + [n, len, a, f@(asLambdaList -> ([_,_], e))])) | not $ inBitSet 0 $ looseVars e - = Just (n, len, a, e) + = Just (n, len, a, f) asGenBVVecTerm _ = Nothing +-- | Match a term of the form @genCryM n a f@ +asGenCryMTerm :: Recognizer Term (Term, Term, Term) +asGenCryMTerm (asApplyAll -> (isGlobalDef "CryptolM.genCryM" -> Just _, + [n, a, f])) + = Just (n, a, f) +asGenCryMTerm _ = Nothing + -- | Match a term of the form @genFromBVVec n len a v def m@ asGenFromBVVecTerm :: Recognizer Term (Term, Term, Term, Term, Term, Term) asGenFromBVVecTerm (asApplyAll -> @@ -94,21 +105,51 @@ boolValToTerm sc (VBool (Right b)) = scBool sc b boolValToTerm _ (VExtra (VExtraTerm _tp tm)) = return tm boolValToTerm _ v = error ("boolValToTerm: unexpected value: " ++ show v) --- | An implementation of a primitive function that expects a @genBVVec@ term -primGenBVVec :: SharedContext -> (Term -> TmPrim) -> TmPrim -primGenBVVec sc f = - PrimFilterFun "genBVVecPrim" +-- | An implementation of a primitive function that expects a term of the form +-- @genBVVec n _ a _@ or @genCryM (bvToNat n _) a _@, where @n@ is the second +-- argument, and passes to the continuation the associated function of type +-- @Vec n Bool -> a@ +primGenBVVec :: SharedContext -> Natural -> (Term -> TmPrim) -> TmPrim +primGenBVVec sc n = + PrimFilterFun "primGenBVVec" $ + \case + VExtra (VExtraTerm _ t) -> primGenBVVecFilter sc n t + _ -> mzero + +-- | The filter function for 'primGenBVVec', and one case of 'primGenCryM' +primGenBVVecFilter :: SharedContext -> Natural -> + Term -> MaybeT (EvalM TermModel) Term +primGenBVVecFilter sc n (asGenBVVecTerm -> Just (asNat -> Just n', _, _, f)) | n == n' = lift $ + do i_tp <- join $ scVecType sc <$> scNat sc n <*> scBoolType sc + let err_tm = error "primGenBVVec: unexpected variable occurrence" + i_tm <- scLocalVar sc 0 + body <- scApplyAllBeta sc f [i_tm, err_tm] + scLambda sc "i" i_tp body +primGenBVVecFilter sc n (asGenCryMTerm -> Just (asBvToNat -> Just (asNat -> Just n', _), _, f)) | n == n' = lift $ + do i_tp <- join $ scVecType sc <$> scNat sc n <*> scBoolType sc + i_tm <- scLocalVar sc 0 + body <- scApplyBeta sc f =<< scBvToNat sc n i_tm + scLambda sc "i" i_tp body +primGenBVVecFilter _ _ t = + error $ "primGenBVVec could not handle: " ++ showInCtx emptyMRVarCtx t + +-- | An implementation of a primitive function that expects a term of the form +-- @genCryM _ a _@, @genFromBVVec ... (genBVVec _ _ a _) ...@, or +-- @genFromBVVec ... (genCryM (bvToNat _ _) a _) ...@, and passes to the +-- continuation either @Just n@ and the associated function of type +-- @Vec n Bool -> a@, or @Nothing@ and the associated function of type +-- @Nat -> a@ +primGenCryM :: SharedContext -> (Maybe Natural -> Term -> TmPrim) -> TmPrim +primGenCryM sc = + PrimFilterFun "primGenCryM" (\case - VExtra (VExtraTerm _ (asGenBVVecTerm -> Just (n, _, _, e))) -> - -- Generate the function \i -> [i/1,error/0]e - lift $ - do i_tp <- scBoolType sc >>= scVecType sc n - let err_tm = error "primGenBVVec: unexpected variable occurrence" - i_tm <- scLocalVar sc 0 - body <- instantiateVarList sc 0 [err_tm,i_tm] e - scLambda sc "i" i_tp body - _ -> mzero) - f + VExtra (VExtraTerm _ (asGenCryMTerm -> Just (_, _, f))) -> + return (Nothing, f) + VExtra (VExtraTerm _ (asGenFromBVVecTerm -> Just (asNat -> Just n, _, _, + v, _, _))) -> + (Just n,) <$> primGenBVVecFilter sc n v + _ -> mzero + ) . uncurry -- | An implementation of a primitive function that expects a bitvector term primBVTermFun :: SharedContext -> (Term -> TmPrim) -> TmPrim @@ -126,39 +167,46 @@ primBVTermFun sc = scVectorReduced sc tp tms v -> lift (putStrLn ("primBVTermFun: unhandled value: " ++ show v)) >> mzero --- | A datatype representing either a @genFromBVVec n len _ v _ _@ term or --- a vector literal, the latter being represented as a list of 'Term's -data FromBVVecOrLit = FromBVVec { fromBVVec_n :: Natural - , fromBVVec_len :: Term - , fromBVVec_vec :: Term } - | BVVecLit [Term] - --- | An implementation of a primitive function that expects either a --- @genFromBVVec@ term or a vector literal -primFromBVVecOrLit :: SharedContext -> TValue TermModel -> - (FromBVVecOrLit -> TmPrim) -> TmPrim -primFromBVVecOrLit sc a = +-- | A datatype representing the arguments to @genBVVecFromVec@ which can be +-- normalized: a @genFromBVVec n len _ v _ _@ term, a @genCryM _ _ body@ term, +-- or a vector literal, the lattermost being represented as a list of 'Term's +data BVVecFromVecArg = FromBVVec { fromBVVec_n :: Natural + , fromBVVec_len :: Term + , fromBVVec_vec :: Term } + | GenCryM Term + | BVVecLit [Term] + +-- | An implementation of a primitive function that expects a @genFromBVVec@ +-- term, a @genCryM@ term, or a vector literal +primBVVecFromVecArg :: SharedContext -> TValue TermModel -> + (BVVecFromVecArg -> TmPrim) -> TmPrim +primBVVecFromVecArg sc a = PrimFilterFun "primFromBVVecOrLit" $ \case VExtra (VExtraTerm _ (asGenFromBVVecTerm -> Just (asNat -> Just n, len, _, v, _, _))) -> return $ FromBVVec n len v + VExtra (VExtraTerm _ (asGenCryMTerm -> Just (_, _, body))) -> + return $ GenCryM body VVector vs -> lift $ BVVecLit <$> traverse (readBackValueNoConfig "primFromBVVecOrLit" sc a <=< force) (V.toList vs) _ -> mzero --- | Turn a 'FromBVVecOrLit' into a BVVec term, assuming it has the given +-- | Turn a 'BVVecFromVecArg' into a BVVec term, assuming it has the given -- bit-width (given as both a 'Natural' and a 'Term'), length, and element type -- FIXME: Properly handle empty vector literals -bvVecFromBVVecOrLit :: SharedContext -> Natural -> Term -> Term -> Term -> - FromBVVecOrLit -> IO Term -bvVecFromBVVecOrLit sc n _ len _ (FromBVVec n' len' v) = +bvVecBVVecFromVecArg :: SharedContext -> Natural -> Term -> Term -> Term -> + BVVecFromVecArg -> IO Term +bvVecBVVecFromVecArg sc n _ len _ (FromBVVec n' len' v) = do len_cvt_len' <- scConvertible sc True len len' if n == n' && len_cvt_len' then return v - else error "bvVecFromBVVecOrLit: genFromBVVec type mismatch" -bvVecFromBVVecOrLit sc n n' len a (BVVecLit vs) = + else error "bvVecBVVecFromVecArg: genFromBVVec type mismatch" +bvVecBVVecFromVecArg sc n _ len a (GenCryM body) = + do len' <- scBvToNat sc n len + scGlobalApply sc "CryptolM.genCryM" [len', a, body] +bvVecBVVecFromVecArg sc n n' len a (BVVecLit vs) = do body <- mkBody 0 vs i_tp <- scBitvector sc n var0 <- scLocalVar sc 0 @@ -166,7 +214,7 @@ bvVecFromBVVecOrLit sc n n' len a (BVVecLit vs) = f <- scLambdaList sc [("i", i_tp), ("pf", pf_tp)] body scGlobalApply sc "Prelude.genBVVec" [n', len, a, f] where mkBody :: Integer -> [Term] -> IO Term - mkBody _ [] = error "bvVecFromBVVecOrLit: empty vector" + mkBody _ [] = error "bvVecBVVecFromVecArg: empty vector" mkBody _ [x] = return $ x mkBody i (x:xs) = do var1 <- scLocalVar sc 1 @@ -196,23 +244,29 @@ readBackValueNoConfig err_str sc tv v = -- | Implementations of primitives for normalizing Mr Solver terms smtNormPrims :: SharedContext -> Map Ident TmPrim smtNormPrims sc = Map.fromList - [ -- Don't unfold @genBVVec@ when normalizing + [ -- Don't unfold @genBVVec@ or @genCryM when normalizing ("Prelude.genBVVec", Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> scGlobalDef sc "Prelude.genBVVec") ), - -- Normalize applications of @genBVVecFromVec@ to a @genFromBVVec@ term or - -- a vector literal into the body of the @genFromBVVec@ term or @genBVVec@ - -- of an sequence of @ite@s defined by the literal, respectively + ("CryptolM.genCryM", + Prim (do tp <- scTypeOfGlobal sc "CryptolM.genCryM" + VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> + scGlobalDef sc "CryptolM.genCryM") + ), + -- Normalize applications of @genBVVecFromVec@ to a @genFromBVVec@ term + -- into the body of the @genFromBVVec@ term, a @genCryM@ term into a + -- @genCryM@ term of the new length, or vector literal into a sequence + -- of @ite@s defined by the literal ("Prelude.genBVVecFromVec", - natFun $ \_m -> tvalFun $ \a -> primFromBVVecOrLit sc a $ \eith -> + natFun $ \_m -> tvalFun $ \a -> primBVVecFromVecArg sc a $ \eith -> PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len -> Prim (do n' <- scNat sc n a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" sc a tp <- scGlobalApply sc "Prelude.BVVec" [n', len, a'] VExtra <$> VExtraTerm (VTyTerm (mkSort 0) tp) <$> - bvVecFromBVVecOrLit sc n n' len a' eith) + bvVecBVVecFromVecArg sc n n' len a' eith) ), -- Don't normalize applications of @genFromBVVec@ ("Prelude.genFromBVVec", @@ -230,12 +284,28 @@ smtNormPrims sc = Map.fromList tm <- scGlobalApply sc "Prelude.genFromBVVec" [n', len', a', v', def', m'] return $ VExtra $ VExtraTerm (VVecType m a) tm) ), - -- Normalize applications of @atBVVec@ to a @genBVVec@ term into an - -- application of the body of the @genBVVec@ term to the index + -- Normalize applications of @atBVVec@ or @atCryM@ to a @genBVVec@ or + -- @genCryM@ term into an application of the body of the term to the index ("Prelude.atBVVec", - PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> - primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> - Prim (VExtra <$> VExtraTerm a <$> scApplyBeta sc f ix) + natFun $ \n -> PrimFun $ \_len -> tvalFun $ \a -> + primGenBVVec sc n $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> + Prim (do tm <- scApplyBeta sc f ix + tm' <- smtNorm sc tm + return $ VExtra $ VExtraTerm a tm') + ), + ("CryptolM.atCryM", + PrimFun $ \_n -> tvalFun $ \a -> + primGenCryM sc $ \nMb f -> PrimStrict $ \ix -> + Prim (do natDT <- scRequireDataType sc preludeNatIdent + let natPN = fmap (const $ VSort (mkSort 0)) (dtPrimName natDT) + let nat_tp = VDataType natPN [] [] + ix' <- readBackValueNoConfig "smtNormPrims (atCryM)" sc nat_tp ix + ix'' <- case nMb of + Nothing -> return ix' + Just n -> scNat sc n >>= \n' -> scBvNat sc n' ix' + tm <- scApplyBeta sc f ix'' + tm' <- smtNorm sc tm + return $ VExtra $ VExtraTerm a tm') ), -- Don't normalize applications of @CompM@ ("Prelude.CompM", @@ -247,14 +317,19 @@ smtNormPrims sc = Map.fromList scGlobalApply sc "Prelude.CompM" [tv_trm])) ] +-- | A version of 'mrNormTerm' in the 'IO' monad, and which does not add any +-- debug output +smtNorm :: SharedContext -> Term -> IO Term +smtNorm sc t = + scGetModuleMap sc >>= \modmap -> + normalizeSharedTerm sc modmap (smtNormPrims sc) Map.empty Set.empty t + -- | Normalize a 'Term' using some Mr Solver specific primitives mrNormTerm :: Term -> MRM Term mrNormTerm t = debugPrint 2 "Normalizing term:" >> debugPrettyInCtx 2 t >> - liftSC0 return >>= \sc -> - liftSC0 scGetModuleMap >>= \modmap -> - liftSC5 normalizeSharedTerm modmap (smtNormPrims sc) Map.empty Set.empty t + liftSC1 smtNorm t -- | Normalize an open term by wrapping it in lambdas, normalizing, and then -- removing those lambdas @@ -300,8 +375,12 @@ mrProvableRaw prop_term = Left msg -> debugPrint 2 ("SMT solver encountered a saw-core error term: " ++ msg) >> return False - Right (Just _, _) -> - debugPrint 2 "SMT solver response: not provable" >> return False + Right (Just cex, _) -> + debugPrint 2 "SMT solver response: not provable, with counterexample:" + >> debugPrint 3 (concatMap (\(x,v) -> + " - " ++ renderSawDoc defaultPPOpts (ppTerm defaultPPOpts (Unshared (FTermF (ExtCns x)))) ++ + " = " ++ renderSawDoc defaultPPOpts (ppFirstOrderValue defaultPPOpts v) ++ "\n") cex) + >> return False Right (Nothing, _) -> debugPrint 2 "SMT solver response: provable" >> return True @@ -538,16 +617,18 @@ mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) throwMRFailure (TypesNotEq (Type tp1) (Type tp2))) >> liftSC0 scBoolType >>= \bool_tp -> liftSC2 scVecType n1 bool_tp >>= \ix_tp -> - withUVarLift "eq_ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $ + withUVarLift "ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $ \ix (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) -> do ix_bound <- liftSC2 scGlobalApply "Prelude.bvult" [n1', ix, len1'] - pf <- liftSC2 scGlobalApply "Prelude.unsafeAssertBVULt" [n1', ix, len1'] + pf_tp <- liftSC1 scEqTrue ix_bound + pf <- mrErrorTerm pf_tp "FIXME" -- FIXME replace this with the below? + -- pf <- liftSC2 scGlobalApply "Prelude.unsafeAssertBVULt" [n1', ix, len1'] t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1', len1', tpA1', t1', ix, pf] t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1', len1', tpA2', t2', ix, pf] cond <- mrProveRelH het tpA1' tpA2' t1_prj t2_prj - extTermInCtx [("eq_ix",ix_tp)] <$> + extTermInCtx [("ix",ix_tp)] <$> liftTermInCtx2 scImplies (TermInCtx [] ix_bound) cond -- For non-BVVec vector types where at least one side is an application of diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 97c6eb0dab..be044eb592 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -585,6 +585,8 @@ matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM () matchCoIndHyp hyp args1 args2 = do mrDebugPPPrefix 1 "matchCoIndHyp" hyp (args1', args2') <- instantiateCoIndHyp hyp + mrDebugPPPrefixSep 3 "matchCoIndHyp args" args1 "," args2 + mrDebugPPPrefixSep 3 "matchCoIndHyp args'" args1' "," args2' eqs1 <- zipWithM mrProveEq args1' args1 eqs2 <- zipWithM mrProveEq args2' args2 if and (eqs1 ++ eqs2) then return () else @@ -1140,11 +1142,10 @@ mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1, _)) t1 -- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that -- variable on the non-BVVec side -- FIXME: Could we handle the a /= a' case here and in generalizeCoIndHypArgs? - Just (HetBVVecVec (n, len, a) (m, a')) -> - do lenNat <- mrBvToNat n len - ms_are_eq <- mrProveEq m lenNat - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else + Just (HetBVVecVec (n, len, a) (m, a2)) -> + do lens_are_eq <- mrProveEq m =<< mrBvToNat n len + as_are_eq <- mrConvertible a a2 + if lens_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 @@ -1156,11 +1157,10 @@ mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1, _)) t1 tps1' <- mrTypeOf t1'' tps2' <- mrTypeOf t2'' >>= liftSC1 scWhnf mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' - Just (HetVecBVVec (m, a') (n, len, a)) -> - do lenNat <- mrBvToNat n len - ms_are_eq <- mrProveEq m lenNat - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else + Just (HetVecBVVec (m, a2) (n, len, a)) -> + do lens_are_eq <- mrProveEq m =<< mrBvToNat n len + as_are_eq <- mrConvertible a a2 + if lens_are_eq && as_are_eq then return () else throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 From 1522ed5fa7f2ea841167ce9101359b480f19e273 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Fri, 26 Aug 2022 15:43:34 -0700 Subject: [PATCH 10/12] fix bug in genBVVecFromVec prim --- src/SAWScript/Prover/MRSolver/SMT.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 0ab447f2a3..3781ca7fab 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -260,7 +260,7 @@ smtNormPrims sc = Map.fromList -- @genCryM@ term of the new length, or vector literal into a sequence -- of @ite@s defined by the literal ("Prelude.genBVVecFromVec", - natFun $ \_m -> tvalFun $ \a -> primBVVecFromVecArg sc a $ \eith -> + PrimFun $ \_m -> tvalFun $ \a -> primBVVecFromVecArg sc a $ \eith -> PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len -> Prim (do n' <- scNat sc n a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" sc a @@ -376,11 +376,11 @@ mrProvableRaw prop_term = debugPrint 2 ("SMT solver encountered a saw-core error term: " ++ msg) >> return False Right (Just cex, _) -> - debugPrint 2 "SMT solver response: not provable, with counterexample:" - >> debugPrint 3 (concatMap (\(x,v) -> - " - " ++ renderSawDoc defaultPPOpts (ppTerm defaultPPOpts (Unshared (FTermF (ExtCns x)))) ++ - " = " ++ renderSawDoc defaultPPOpts (ppFirstOrderValue defaultPPOpts v) ++ "\n") cex) - >> return False + debugPrint 2 "SMT solver response: not provable" >> + debugPrint 3 ("Counterexample:" ++ concatMap (\(x,v) -> + "\n - " ++ renderSawDoc defaultPPOpts (ppTerm defaultPPOpts (Unshared (FTermF (ExtCns x)))) ++ + " = " ++ renderSawDoc defaultPPOpts (ppFirstOrderValue defaultPPOpts v)) cex) >> + return False Right (Nothing, _) -> debugPrint 2 "SMT solver response: provable" >> return True From f496f696e60f40bb731caf9514d9dd6a23d85738 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Fri, 9 Sep 2022 15:40:35 -0400 Subject: [PATCH 11/12] replace matchHet with findInjConvs --- heapster-saw/examples/sha512_mr_solver.saw | 10 +- src/SAWScript/Prover/MRSolver/Monad.hs | 47 +-- src/SAWScript/Prover/MRSolver/SMT.hs | 404 +++++++++++++++------ src/SAWScript/Prover/MRSolver/Solver.hs | 350 +++++++----------- src/SAWScript/Prover/MRSolver/Term.hs | 4 + 5 files changed, 430 insertions(+), 385 deletions(-) diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 911d92f0a2..107a825594 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -27,8 +27,8 @@ monadify_term {{ processBlock_spec }}; monadify_term {{ processBlocks_loop_spec }}; monadify_term {{ processBlocks_spec }}; -mr_solver_set_debug_level 3; -mr_solver_assume round_00_15 {{ round_00_15_spec }}; -mr_solver_assume round_16_80 {{ round_16_80_spec }}; -mr_solver_assume processBlock {{ processBlock_spec }}; -mr_solver_prove processBlocks {{ processBlocks_spec }}; +// mr_solver_set_debug_level 3; +mr_solver_prove round_00_15 {{ round_00_15_spec }}; +// mr_solver_prove round_16_80 {{ round_16_80_spec }}; +// mr_solver_prove processBlock {{ processBlock_spec }}; +// mr_solver_prove processBlocks {{ processBlocks_spec }}; diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 9c751cc30d..44ae0a8765 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -25,7 +25,6 @@ module SAWScript.Prover.MRSolver.Monad where import Data.List (find, findIndex, foldl') import qualified Data.Text as T -import Numeric.Natural (Natural) import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State @@ -252,12 +251,6 @@ coIndHypSetArg hyp@(CoIndHyp {..}) (Left i) x = coIndHypSetArg hyp@(CoIndHyp {..}) (Right i) x = hyp { coIndHypRHS = take i coIndHypRHS ++ x : drop (i+1) coIndHypRHS } --- | Set all of the arguments in the given list to the given value in a --- coinductive hypothesis, using 'coIndHypSetArg' -coIndHypSetArgs :: CoIndHyp -> [Either Int Int] -> Term -> CoIndHyp -coIndHypSetArgs hyp specs x = - foldl' (\hyp' spec -> coIndHypSetArg hyp' spec x) hyp specs - -- | Add a variable to the context of a coinductive hypothesis, returning the -- updated coinductive hypothesis and a 'Term' which is the new variable coIndHypWithVar :: CoIndHyp -> LocalName -> Type -> MRM (CoIndHyp, Term) @@ -471,42 +464,6 @@ liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) ----------------------------------------------------------------------- --- * Relating Types Heterogeneously ----------------------------------------------------------------------- - --- | A datatype encapsulating all the way in which we consider two types to --- be heterogeneously related: either one is a @Num@ and the other is a @Nat@, --- one is a @BVVec@ and the other is a non-@BVVec@ vector (of the same length, --- which must be checked where 'matchHet' is used), or both sides are pairs --- (whose components are respectively heterogeneously related, which must be --- checked where 'matchHet' is used). See 'typesHetRelated' for an example. -data HetRelated = HetBVNum Natural - | HetNumBV Natural - | HetBVVecVec (Term, Term, Term) (Term, Term) - | HetVecBVVec (Term, Term) (Term, Term, Term) - | HetPair (Term, Term) (Term, Term) - --- | Check to see if the given types match one of the cases of 'HetRelated' -matchHet :: Term -> Term -> Maybe HetRelated -matchHet (asBitvectorType -> Just n) - (asDataType -> Just (primName -> "Cryptol.Num", _)) = - Just $ HetBVNum n -matchHet (asDataType -> Just (primName -> "Cryptol.Num", _)) - (asBitvectorType -> Just n) = - Just $ HetNumBV n -matchHet (asBVVecType -> Just (n, len, a)) - (asNonBVVecVectorType -> Just (m, a')) = - Just $ HetBVVecVec (n, len, a) (m, a') -matchHet (asNonBVVecVectorType -> Just (m, a')) - (asBVVecType -> Just (n, len, a)) = - Just $ HetVecBVVec (m, a') (n, len, a) -matchHet (asPairType -> Just (tpL1, tpR1)) - (asPairType -> Just (tpL2, tpR2)) = - Just $ HetPair (tpL1, tpR1) (tpL2, tpR2) -matchHet _ _ = Nothing - - ---------------------------------------------------------------------- -- * Functions for Building Terms ---------------------------------------------------------------------- @@ -542,6 +499,10 @@ mrGenFromBVVec n len a v def_err_str m = -- | Apply a 'TermProj' to perform a projection on a 'Term' doTermProj :: Term -> TermProj -> MRM Term +doTermProj (asPairValue -> Just (t, _)) TermProjLeft = return t +doTermProj (asPairValue -> Just (_, t)) TermProjRight = return t +doTermProj (asRecordValue -> Just t_map) (TermProjRecord fld) + | Just t <- Map.lookup fld t_map = return t doTermProj t TermProjLeft = liftSC1 scPairLeft t doTermProj t TermProjRight = liftSC1 scPairRight t doTermProj t (TermProjRecord fld) = liftSC2 scRecordSelect t fld diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 3781ca7fab..d70c053e9f 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -3,6 +3,10 @@ {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE ImplicitParams #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE StandaloneDeriving #-} {- | Module : SAWScript.Prover.MRSolver.SMT @@ -23,6 +27,8 @@ import Numeric.Natural (Natural) import Control.Monad.Except import qualified Control.Exception as X import Control.Monad.Trans.Maybe +import Data.Foldable (foldrM, foldlM) +import GHC.Generics import Data.Map (Map) import qualified Data.Map as Map @@ -419,21 +425,250 @@ mrProvable bool_tm = ---------------------------------------------------------------------- --- * Relating Types Heterogeneously with SMT +-- * Finding injective conversions ---------------------------------------------------------------------- --- | Return true iff the given types are heterogeneously related -typesHetRelated :: Term -> Term -> MRM Bool -typesHetRelated tp1 tp2 = case matchHet tp1 tp2 of - Just (HetBVNum _) -> return True - Just (HetNumBV _) -> return True - Just (HetBVVecVec (n, len, a) (m, a')) -> mrBvToNat n len >>= \m' -> - (&&) <$> mrProveEq m m' <*> typesHetRelated a a' - Just (HetVecBVVec (m, a') (n, len, a)) -> mrBvToNat n len >>= \m' -> - (&&) <$> mrProveEq m m' <*> typesHetRelated a a' - Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> - (&&) <$> typesHetRelated tpL1 tpL2 <*> typesHetRelated tpR1 tpR2 - Nothing -> mrConvertible tp1 tp2 +-- | An injection from @Nat@ to @Num@ ('NatToNum'), @Vec n Bool@ to @Nat@ +-- ('BVToNat'), @BVVec n len a@ to @Vec m a@ ('BVVecToVec'), from one pair +-- type to another ('PairToPair'), or any composition of these using '(<>)' +-- (including the composition of none of them, the identity 'NoConv'). This +-- type is primarily used as one of the returns of 'findInjConvs'. +-- NOTE: Do not use the constructors of this type or 'SingleInjConversion' +-- directly, instead use the pattern synonyms mentioned above and '(<>)' to +-- create and compose 'InjConversion's. This ensures elements of this type +-- are always in a normal form w.r.t. 'PairToPair' injections. +newtype InjConversion = ConvComp [SingleInjConversion] + deriving (Generic, Show) + +-- | Used in the implementation of 'InjConversion'. +-- NOTE: Do not use the constructors of this type or 'InjConversion' +-- directly, instead use the pattern synonyms mentioned in the documentation of +-- 'InjConversion' and '(<>)' to create and compose 'InjConversion's. This +-- ensures elements of this type are always in a normal form w.r.t. +-- 'PairToPair' injections. +data SingleInjConversion = SingleNatToNum + | SingleBVToNat Natural + | SingleBVVecToVec Term Term Term Term + | SinglePairToPair InjConversion InjConversion + deriving (Generic, Show) + +deriving instance TermLike SingleInjConversion +deriving instance TermLike InjConversion + +-- | The identity 'InjConversion' +pattern NoConv :: InjConversion +pattern NoConv = ConvComp [] + +-- | The injective conversion from @Nat@ to @Num@ +pattern NatToNum :: InjConversion +pattern NatToNum = ConvComp [SingleNatToNum] + +-- | The injective conversion from @Vec n Bool@ to @Nat@ for a given @n@ +pattern BVToNat :: Natural -> InjConversion +pattern BVToNat n = ConvComp [SingleBVToNat n] + +-- | The injective conversion from @BVVec n len a@ to @Vec m a@ for given +-- @n@, @len@, @a@, and @m@ (in that order), assuming @m >= bvToNat n len@ +pattern BVVecToVec :: Term -> Term -> Term -> Term -> InjConversion +pattern BVVecToVec n len a m = ConvComp [SingleBVVecToVec n len a m] + +-- | An injective conversion from one pair type to another, using the given +-- 'InjConversion's for the first and second projections, respectively +pattern PairToPair :: InjConversion -> InjConversion -> InjConversion +pattern PairToPair c1 c2 <- ConvComp [SinglePairToPair c1 c2] + where PairToPair NoConv NoConv = NoConv + PairToPair c1 c2 = ConvComp [SinglePairToPair c1 c2] + +instance Semigroup InjConversion where + (ConvComp cs1) <> (ConvComp cs2) = ConvComp (cbnPairs $ cs1 ++ cs2) + where cbnPairs :: [SingleInjConversion] -> [SingleInjConversion] + cbnPairs (SinglePairToPair cL1 cR1 : SinglePairToPair cL2 cR2 : cs) = + cbnPairs (SinglePairToPair (cL1 <> cL2) (cR1 <> cR2) : cs) + cbnPairs (s : cs) = s : cbnPairs cs + cbnPairs [] = [] + +instance Monoid InjConversion where + mempty = NoConv + +-- | Return 'True' iff the given 'InjConversion' is not 'NoConv' +nonTrivialConv :: InjConversion -> Bool +nonTrivialConv (ConvComp cs) = not (null cs) + +-- | Return 'True' iff the given 'InjConversion's are convertible, i.e. if +-- the two injective conversions are the compositions of the same constructors, +-- and the arguments to those constructors are convertible via 'mrConvertible' +mrConvsConvertible :: InjConversion -> InjConversion -> MRM Bool +mrConvsConvertible (ConvComp cs1) (ConvComp cs2) = + and <$> zipWithM mrSingleConvsConvertible cs1 cs2 + where mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool + mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True + mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2 + mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1) + (SingleBVVecToVec n2 len2 a2 m2) = + do ns_are_eq <- mrConvertible n1 n2 + lens_are_eq <- mrConvertible len1 len2 + as_are_eq <- mrConvertible a1 a2 + ms_are_eq <- mrConvertible m1 m2 + return $ ns_are_eq && lens_are_eq && as_are_eq && ms_are_eq + mrSingleConvsConvertible (SinglePairToPair cL1 cR1) + (SinglePairToPair cL2 cR2) = + do cLs_are_eq <- mrConvsConvertible cL1 cL2 + cRs_are_eq <- mrConvsConvertible cR1 cR2 + return $ cLs_are_eq && cRs_are_eq + mrSingleConvsConvertible _ _ = return False + +-- | Apply the given 'InjConversion' to the given term, where compositions +-- @c1 <> c2 <> ... <> cn@ are applied from right to left as in function +-- composition (i.e. @mrApplyConv (c1 <> c2 <> ... <> cn) t@ is equivalent to +-- @mrApplyConv c1 (mrApplyConv c2 (... mrApplyConv cn t ...))@) +mrApplyConv :: InjConversion -> Term -> MRM Term +mrApplyConv (ConvComp cs) = flip (foldrM go) cs + where go :: SingleInjConversion -> Term -> MRM Term + go SingleNatToNum t = liftSC2 scCtorApp "Cryptol.TCNum" [t] + go (SingleBVToNat n) t = liftSC2 scBvToNat n t + go (SingleBVVecToVec n len a m) t = mrGenFromBVVec n len a t "mrApplyConv" m + go (SinglePairToPair c1 c2) t = + do t1 <- mrApplyConv c1 =<< doTermProj t TermProjLeft + t2 <- mrApplyConv c2 =<< doTermProj t TermProjRight + liftSC2 scPairValueReduced t1 t2 + +-- | Try to apply the inverse of the given the conversion to the given term, +-- raising an error if this is not possible - see also 'mrApplyConv' +mrApplyInvConv :: InjConversion -> Term -> MRM Term +mrApplyInvConv (ConvComp cs) = flip (foldlM go) cs + where go :: Term -> SingleInjConversion -> MRM Term + go t SingleNatToNum = case asNum t of + Just (Left t') -> return t' + _ -> error "mrApplyInvConv: Num term does not normalize to TCNum constructor" + go t (SingleBVToNat n) = + do n_tm <- liftSC1 scNat n + liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t] + go t (SingleBVVecToVec n len a m) = + mrGenBVVecFromVec m a t "mrApplyInvConv" n len + go t (SinglePairToPair c1 c2) = + do t1 <- mrApplyInvConv c1 =<< doTermProj t TermProjLeft + t2 <- mrApplyInvConv c2 =<< doTermProj t TermProjRight + liftSC2 scPairValueReduced t1 t2 + +-- | If the given term can be expressed as @mrApplyInvConv c t@ for some @c@ +-- and @t@, return @c@ - otherwise return @NoConv@ +mrConvOfTerm :: Term -> InjConversion +mrConvOfTerm (asNum -> Just (Left t')) = + NatToNum <> mrConvOfTerm t' +mrConvOfTerm (asBvToNat -> Just (asNat -> Just n, t')) = + BVToNat n <> mrConvOfTerm t' +mrConvOfTerm (asGenFromBVVecTerm -> Just (n, len, a, v, _, m)) = + BVVecToVec n len a m <> mrConvOfTerm v +mrConvOfTerm (asPairValue -> Just (t1, t2)) = + PairToPair (mrConvOfTerm t1) (mrConvOfTerm t2) +mrConvOfTerm _ = NoConv + +-- | For two types @tp1@ and @tp2@, and optionally two terms @t1 :: tp1@ and +-- @t2 :: tp2@, tries to find a type @tp@ and 'InjConversion's @c1@ and @c2@ +-- such that @c1@ is an injective conversion from @tp@ to @tp1@ and @c2@ is a +-- injective conversion from @tp@ to @tp2@. This tries to make @c1@ and @c2@ +-- as large as possible, using information from the given terms (i.e. using +-- 'mrConvOfTerm') where possible. In pictorial form, this function finds +-- a @tp@, @c1@, and @c2@ which satisfy the following diagram: +-- +-- > tp1 tp2 +-- > ^ ^ +-- > c1 \ / c2 +-- > \ / +-- > tp +-- +-- Since adding a 'NatToNum' conversion does not require any choice (i.e. +-- unlike 'BVToNat', which requires choosing a bit width), if either @tp1@ or +-- @tp2@ is @Num@, a 'NatToNum' conversion will be included on the respective +-- side. Another subtlety worth noting is the difference between returning +-- @Just (tp, NoConv, NoConv)@ and @Nothing@ - the former indicates that the +-- types @tp1@ and @tp2@ are convertible, but the latter indicates that no +-- 'InjConversion' could be found. +findInjConvs :: Term -> Maybe Term -> Term -> Maybe Term -> + MRM (Maybe (Term, InjConversion, InjConversion)) +-- always add 'NatToNum' conversions +findInjConvs (asDataType -> Just (primName -> "Cryptol.Num", _)) t1 tp2 t2 = + do tp1' <- liftSC0 scNatType + t1' <- mapM (mrApplyInvConv NatToNum) t1 + mb_cs <- findInjConvs tp1' t1' tp2 t2 + return $ fmap (\(tp, c1, c2) -> (tp, NatToNum <> c1, c2)) mb_cs +findInjConvs tp1 t1 (asDataType -> Just (primName -> "Cryptol.Num", _)) t2 = + do tp2' <- liftSC0 scNatType + t2' <- mapM (mrApplyInvConv NatToNum) t2 + mb_cs <- findInjConvs tp1 t1 tp2' t2' + return $ fmap (\(tp, c1, c2) -> (tp, c1, NatToNum <> c2)) mb_cs +-- add a 'BVToNat' conversion if the (optional) given term has a 'BVToNat' +-- conversion +findInjConvs (asNatType -> Just ()) + (Just (asBvToNat -> Just (asNat -> Just n, t1'))) tp2 t2 = + do tp1' <- liftSC1 scBitvector n + mb_cs <- findInjConvs tp1' (Just t1') tp2 t2 + return $ fmap (\(tp, c1, c2) -> (tp, BVToNat n <> c1, c2)) mb_cs +findInjConvs tp1 t1 (asNatType -> Just ()) + (Just (asBvToNat -> Just (asNat -> Just n, t2'))) = + do tp2' <- liftSC1 scBitvector n + mb_cs <- findInjConvs tp1 t1 tp2' (Just t2') + return $ fmap (\(tp, c1, c2) -> (tp, c1, BVToNat n <> c2)) mb_cs +-- add a 'BVToNat' conversion we have a BV on the other side, using the +-- bit-width from the other side +findInjConvs (asNatType -> Just ()) _ (asBitvectorType -> Just n) _ = + do bv_tp <- liftSC1 scBitvector n + return $ Just (bv_tp, BVToNat n, NoConv) +findInjConvs (asBitvectorType -> Just n) _ (asNatType -> Just ()) _ = + do bv_tp <- liftSC1 scBitvector n + return $ Just (bv_tp, NoConv, BVToNat n) +-- add a 'BVVecToVec' conversion if the (optional) given term has a +-- 'BVVecToVec' conversion +findInjConvs (asNonBVVecVectorType -> Just (m, _)) + (Just (asGenFromBVVecTerm -> Just (n, len, a, t1', _, _))) tp2 t2 = + do len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp1' <- liftSC2 scVecType len' a + mb_cs <- findInjConvs tp1' (Just t1') tp2 t2 + return $ fmap (\(tp, c1, c2) -> (tp, BVVecToVec n len a m <> c1, c2)) mb_cs +findInjConvs tp1 t1 (asNonBVVecVectorType -> Just (m, _)) + (Just (asGenFromBVVecTerm -> Just (n, len, a, t2', _, _))) = + do len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp2' <- liftSC2 scVecType len' a + mb_cs <- findInjConvs tp1 t1 tp2' (Just t2') + return $ fmap (\(tp, c1, c2) -> (tp, c1, BVVecToVec n len a m <> c2)) mb_cs +-- add a 'BVVecToVec' conversion we have a BVVec on the other side, using the +-- bit-width from the other side +findInjConvs (asNonBVVecVectorType -> Just (m, a')) _ + (asBVVecType -> Just (n, len, a)) _ = + do bvvec_tp <- liftSC2 scVecType n a + lens_are_eq <- mrProveEq m =<< mrBvToNat n len + as_are_eq <- mrConvertible a a' + if lens_are_eq && as_are_eq + then return $ Just (bvvec_tp, BVVecToVec n len a m, NoConv) + else return $ Nothing +findInjConvs (asBVVecType -> Just (n, len, a)) _ + (asNonBVVecVectorType -> Just (m, a')) _ = + do bvvec_tp <- liftSC2 scVecType n a + lens_are_eq <- mrProveEq m =<< mrBvToNat n len + as_are_eq <- mrConvertible a a' + if lens_are_eq && as_are_eq + then return $ Just (bvvec_tp, NoConv, BVVecToVec n len a m) + else return $ Nothing +-- add a 'pairToPair' conversion if we have pair types on both sides +findInjConvs (asPairType -> Just (tpL1, tpR1)) t1 + (asPairType -> Just (tpL2, tpR2)) t2 = + do tL1 <- mapM (flip doTermProj TermProjLeft ) t1 + tR1 <- mapM (flip doTermProj TermProjRight) t1 + tL2 <- mapM (flip doTermProj TermProjLeft ) t2 + tR2 <- mapM (flip doTermProj TermProjRight) t2 + mb_cLs <- findInjConvs tpL1 tL1 tpL2 tL2 + mb_cRs <- findInjConvs tpR1 tR1 tpR2 tR2 + case (mb_cLs, mb_cRs) of + (Just (tpL, cL1, cL2), Just (tpR, cR1, cR2)) -> + do pair_tp <- liftSC2 scPairType tpL tpR + return $ Just (pair_tp, PairToPair cL1 cR1, PairToPair cL2 cR2) + _ -> return $ Nothing +-- otherwise, just check that the types are convertible +findInjConvs tp1 _ tp2 _ = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq + then return $ Just (tp1, NoConv, NoConv) + else return $ Nothing ---------------------------------------------------------------------- @@ -458,10 +693,9 @@ mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = liftSC3 scBvEq n t1 t2 mrEq' (asDataType -> Just (primName -> "Cryptol.Num", _)) t1 t2 = - liftSC1 scWhnf t1 >>= \t1' -> liftSC1 scWhnf t2 >>= \t2' -> case (t1', t2') of - (asCtor -> Just (primName -> "Cryptol.TCNum", [t1'']), - asCtor -> Just (primName -> "Cryptol.TCNum", [t2''])) -> - liftSC0 scNatType >>= \nat_tp -> mrEq' nat_tp t1'' t2'' + (,) <$> liftSC1 scWhnf t1 <*> liftSC1 scWhnf t2 >>= \case + (asNum -> Just (Left t1'), asNum -> Just (Left t2')) -> + liftSC0 scNatType >>= \nat_tp -> mrEq' nat_tp t1' t2' _ -> error "mrEq': Num terms do not normalize to TCNum constructors" mrEq' _ _ _ = error "mrEq': unsupported type" @@ -593,7 +827,9 @@ mrProveRelH' var_map _ tp1 tp2 t1 (asEVarApp var_map -> Just (evar, args, Nothin mrProveRelH' _ _ (asTupleType -> Just []) (asTupleType -> Just []) _ _ = TermInCtx [] <$> liftSC1 scBool True --- For nat, bitvector, Boolean, and integer types, call mrProveEqSimple +-- For Num, nat, bitvector, Boolean, and integer types, call mrProveEqSimple +mrProveRelH' _ _ _ _ (asNum -> Just (Left t1)) (asNum -> Just (Left t2)) = + mrProveEqSimple (liftSC2 scEqualNat) t1 t2 mrProveRelH' _ _ (asNatType -> Just _) (asNatType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scEqualNat) t1 t2 mrProveRelH' _ _ tp1@(asVectorType -> Just (n1, asBoolType -> Just ())) @@ -631,107 +867,33 @@ mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) extTermInCtx [("ix",ix_tp)] <$> liftTermInCtx2 scImplies (TermInCtx [] ix_bound) cond --- For non-BVVec vector types where at least one side is an application of --- genFromBVVec, wrap both sides in genBVVecFromVec and recurse -mrProveRelH' _ het tp1@(asNonBVVecVectorType -> Just (m1, tpA1)) - tp2@(asNonBVVecVectorType -> Just (m2, tpA2)) - t1@(asGenFromBVVecTerm -> Just (n, len, _, _, _, _)) t2 = - do ms_are_eq <- mrConvertible m1 m2 - if ms_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] - tp1' <- liftSC2 scVecType len' tpA1 - tp2' <- liftSC2 scVecType len' tpA2 - t1' <- mrGenBVVecFromVec m1 tpA1 t1 "mrProveRelH (BVVec/BVVec)" n len - t2' <- mrGenBVVecFromVec m2 tpA2 t2 "mrProveRelH (BVVec/BVVec)" n len - mrProveRelH het tp1' tp2' t1' t2' -mrProveRelH' _ het tp1@(asNonBVVecVectorType -> Just (m1, tpA1)) - tp2@(asNonBVVecVectorType -> Just (m2, tpA2)) - t1 t2@(asGenFromBVVecTerm -> Just (n, len, _, _, _, _)) = - do ms_are_eq <- mrConvertible m1 m2 - if ms_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] - tp1' <- liftSC2 scVecType len' tpA1 - tp2' <- liftSC2 scVecType len' tpA2 - t1' <- mrGenBVVecFromVec m1 tpA1 t1 "mrProveRelH (BVVec/BVVec)" n len - t2' <- mrGenBVVecFromVec m2 tpA2 t2 "mrProveRelH (BVVec/BVVec)" n len - mrProveRelH het tp1' tp2' t1' t2' - -mrProveRelH' _ True tp1 tp2 t1 t2 | Just mh <- matchHet tp1 tp2 = case mh of - - -- If our relation is heterogeneous and we have a bitvector on one side and - -- a Num on the other, ensure that the Num term is TCNum of some Nat, wrap - -- the Nat with bvNat, and recurse - HetBVNum n - | Just (primName -> "Cryptol.TCNum", [t2']) <- asCtor t2 -> - do n_tm <- liftSC1 scNat n - t2'' <- liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t2'] - mrProveRelH True tp1 tp1 t1 t2'' - | otherwise -> throwMRFailure (TermsNotEq t1 t2) - HetNumBV n - | Just (primName -> "Cryptol.TCNum", [t1']) <- asCtor t1 -> - do n_tm <- liftSC1 scNat n - t1'' <- liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t1'] - mrProveRelH True tp1 tp1 t1'' t2 - | otherwise -> throwMRFailure (TermsNotEq t1 t2) - - -- If our relation is heterogeneous and we have a BVVec on one side and a - -- non-BVVec vector on the other, wrap the non-BVVec vector term in - -- genBVVecFromVec and recurse - HetBVVecVec (n, len, _) (m, tpA2) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrProveEq m' m - if ms_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] - tp2' <- liftSC2 scVecType len' tpA2 - t2' <- mrGenBVVecFromVec m tpA2 t2 "mrProveRelH (BVVec/Vec)" n len - -- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "an`d" t2' - mrProveRelH True tp1 tp2' t1 t2' - HetVecBVVec (m, tpA1) (n, len, _) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrProveEq m' m - if ms_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] - tp1' <- liftSC2 scVecType len' tpA1 - t1' <- mrGenBVVecFromVec m tpA1 t1 "mrProveRelH (Vec/BVVec)" n len - -- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2 - mrProveRelH True tp1' tp2 t1' t2 - - -- For pair types, prove both the left and right projections are related - -- (this should be the same as the pair case below - we have to split them - -- up because otherwise GHC 9.0's pattern match checker complains...) - HetPair (tpL1, tpR1) (tpL2, tpR2) -> - do t1L <- liftSC1 scPairLeft t1 - t2L <- liftSC1 scPairLeft t2 - t1R <- liftSC1 scPairRight t1 - t2R <- liftSC1 scPairRight t2 - condL <- mrProveRelH True tpL1 tpL2 t1L t2L - condR <- mrProveRelH True tpR1 tpR2 t1R t2R - liftTermInCtx2 scAnd condL condR - -- For pair types, prove both the left and right projections are related --- (this should be the same as the pair case below - we have to split them --- up because otherwise GHC 9.0's pattern match checker complains...) -mrProveRelH' _ False tp1 tp2 t1 t2 - | Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) <- matchHet tp1 tp2 = - do t1L <- liftSC1 scPairLeft t1 - t2L <- liftSC1 scPairLeft t2 - t1R <- liftSC1 scPairRight t1 - t2R <- liftSC1 scPairRight t2 - condL <- mrProveRelH False tpL1 tpL2 t1L t2L - condR <- mrProveRelH False tpR1 tpR2 t1R t2R - liftTermInCtx2 scAnd condL condR - --- As a fallback, for types we can't handle, just check convertibility -mrProveRelH' _ het tp1 tp2 t1 t2 = - do success <- mrConvertible t1 t2 - tps_eq <- mrConvertible tp1 tp2 - if success then return () else - if het || not tps_eq - then mrDebugPPPrefixSep 2 "mrProveRelH' could not match types: " tp1 "and" tp2 >> - mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 - else mrDebugPPPrefixSep 2 "mrProveEq could not prove convertible: " t1 "and" t2 - TermInCtx [] <$> liftSC1 scBool success +mrProveRelH' _ het (asPairType -> Just (tpL1, tpR1)) + (asPairType -> Just (tpL2, tpR2)) t1 t2 = + do t1L <- liftSC1 scPairLeft t1 + t2L <- liftSC1 scPairLeft t2 + t1R <- liftSC1 scPairRight t1 + t2R <- liftSC1 scPairRight t2 + condL <- mrProveRelH het tpL1 tpL2 t1L t2L + condR <- mrProveRelH het tpR1 tpR2 t1R t2R + liftTermInCtx2 scAnd condL condR + +mrProveRelH' _ het tp1 tp2 t1 t2 = findInjConvs tp1 (Just t1) tp2 (Just t2) >>= \case + -- If we are allowing heterogeneous equality and we can find non-trivial + -- injective conversions from a type @tp@ to @tp1@ and @tp2@, apply the + -- inverses of these conversions to @t1@ and @t2@ and continue checking + -- equality on the results + Just (tp, c1, c2) | het, nonTrivialConv c1 || nonTrivialConv c2 -> do + t1' <- mrApplyInvConv c1 t1 + t2' <- mrApplyInvConv c2 t2 + mrProveRelH True tp tp t1' t2' + -- Otherwise, just check convertibility + _ -> do + success <- mrConvertible t1 t2 + tps_eq <- mrConvertible tp1 tp2 + if success then return () else + if het || not tps_eq + then mrDebugPPPrefixSep 2 "mrProveRelH' could not match types: " tp1 "and" tp2 >> + mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 + else mrDebugPPPrefixSep 2 "mrProveEq could not prove convertible: " t1 "and" t2 + TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index be044eb592..56869fa8aa 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -578,7 +578,6 @@ proveCoIndHyp hyp = withFailureCtx (FailCtxCoIndHyp hyp) $ proveCoIndHyp hyp' e -> throwError e - -- | Test that a coinductive hypothesis for the given function names matches the -- given arguments, otherwise throw an exception saying that widening is needed matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM () @@ -594,112 +593,84 @@ matchCoIndHyp hyp args1 args2 = (map Left (findIndices not eqs1) ++ map Right (findIndices not eqs2)) proveCoIndHypInvariant hyp - -- | Generalize some of the arguments of a coinductive hypothesis generalizeCoIndHyp :: CoIndHyp -> [Either Int Int] -> MRM CoIndHyp generalizeCoIndHyp hyp [] = return hyp -generalizeCoIndHyp hyp all_specs@(arg_spec:arg_specs) = +generalizeCoIndHyp hyp all_specs@(arg_spec_0:arg_specs) = withOnlyUVars (coIndHypCtx hyp) $ do withNoUVars $ mrDebugPPPrefixSep 2 "generalizeCoIndHyp with indices" all_specs "on" hyp -- Get the arg and type associated with arg_spec - let arg = coIndHypArg hyp arg_spec - arg_tp <- mrTypeOf arg - -- Sort out the other args that are heterogeneously related to arg - eq_uneq_specs <- forM arg_specs $ \spec' -> - do let arg' = coIndHypArg hyp spec' - tp' <- mrTypeOf arg' - tps_rel <- typesHetRelated arg_tp tp' - args_rel <- if tps_rel then mrProveRel True arg arg' else return False - return $ if args_rel then Left (spec', tp') else Right spec' - let (eq_specs, uneq_specs) = partitionEithers eq_uneq_specs - -- Group the eq_specs by their type, i.e. turn a list @[(Idx, Type)]@ into a - -- list @[([Idx], Type)]@, where all the indices in each pair share the same - -- type (as in 'mrConvertible') - let addArgByTp :: [([a], Term)] -> (a, Term) -> MRM [([a], Term)] - addArgByTp [] (x, tp) = return [([x], tp)] - addArgByTp ((xs, tp):xstps) (x, tp') = - do tps_eq <- mrConvertible tp' tp - if tps_eq then return ((x:xs, tp):xstps) - else ((xs, tp):) <$> addArgByTp xstps (x, tp') - eq_specs_gpd <- foldlM addArgByTp [] ((arg_spec,arg_tp):eq_specs) - -- Add a new variable, set all the indices in @eq_specs_gpd@ to it as in - -- 'generalizeCoIndHypArgs', and recurse - hyp' <- generalizeCoIndHypArgs hyp eq_specs_gpd - generalizeCoIndHyp hyp' uneq_specs - --- | Assuming all the types in the given list are related by 'typesHetRelated' --- and no two of them are convertible, add a new variable and set all of --- indices in the given list to it, modulo possibly some wrapper functions --- determined by how the types are heterogeneously related -generalizeCoIndHypArgs :: CoIndHyp -> [([Either Int Int], Term)] -> MRM CoIndHyp - --- If all the arguments we need to generalize have the same type, introduce a --- new variable and set all of the given arguments to it -generalizeCoIndHypArgs hyp [(specs, tp)] = - do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp) - return $ coIndHypSetArgs hyp' specs var - -generalizeCoIndHypArgs hyp [(specs1, tp1), (specs2, tp2)] = case matchHet tp1 tp2 of - - -- If we need to generalize bitvector arguments with Num arguments, introduce - -- a bitvector variable and set all of the bitvector arguments to it and - -- all of the Num arguments to `TCNum` of `bvToNat` of it - Just (HetBVNum n) -> - do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp1) - nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - let hyp'' = coIndHypSetArgs hyp' specs1 var - return $ coIndHypSetArgs hyp'' specs2 num_tm - Just (HetNumBV n) -> - do (hyp', var) <- coIndHypWithVar hyp "z" (Type tp2) - nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - let hyp'' = coIndHypSetArgs hyp' specs1 num_tm - return $ coIndHypSetArgs hyp'' specs2 var - - -- If we need to generalize BVVec arguments with Vec arguments, introduce a - -- BVVec variable and set all of the BVVec arguments to it and all of the - -- Vec arguments to `genBVVecFromVec` of it - -- FIXME: Could we handle the a /= a' case here and in mrRefinesFunH? - Just (HetBVVecVec (n, len, a) (m, a')) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrProveEq m m' - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - (hyp', var) <- coIndHypWithVar hyp "z" (Type tp1) - bvv_tm <- mrGenFromBVVec n len a var "generalizeCoIndHypArgs (BVVec/Vec)" m - let hyp'' = coIndHypSetArgs hyp' specs1 var - return $ coIndHypSetArgs hyp'' specs2 bvv_tm - Just (HetVecBVVec (m, a') (n, len, a)) -> - do m' <- mrBvToNat n len - ms_are_eq <- mrProveEq m m' - as_are_eq <- mrConvertible a a' - if ms_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - (hyp', var) <- coIndHypWithVar hyp "z" (Type tp2) - bvv_tm <- mrGenFromBVVec n len a var "generalizeCoIndHypArgs (Vec/BVVec)" m - let hyp'' = coIndHypSetArgs hyp' specs1 bvv_tm - return $ coIndHypSetArgs hyp'' specs2 var - - -- This case should be unreachable because in 'mrRefinesFunH' we always - -- expand all tuples - though in principle we could handle it - Just (HetPair _ _) -> - debugPrint 0 "generalizeCoIndHypArgs: trying to widen distinct tuple types:" >> - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - - Nothing -> throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - -generalizeCoIndHypArgs _ specs = mrUVars >>= \uvar_ctx -> - -- Being in this case implies we have types @tp1, tp2, tp3@ which are related - -- by 'typesHetRelated' but no two of them are convertible. As of the time of - -- writing, the only way this could be possible is if the types are pair - -- types related in different components (e.g. @(a,b), (a',b), (a,b')@). In - -- 'mrRefinesFunH' we always expand all tuples, so when we hit this function - -- no such types should remain. - error $ "generalizeCoIndHypArgs: too many distinct types to widen: " - ++ showInCtx uvar_ctx specs + let arg_tm_0 = coIndHypArg hyp arg_spec_0 + arg_tp_0 <- mrTypeOf arg_tm_0 + -- Partition @arg_specs@ into a left list (@eq_specs@) and a right list + -- (@uneq_specs@) where an @arg_spec_i@ is put in the left list if + -- 'findInjConvs' returns 'Just' and @arg_tm_0@ and @arg_tm_i@ are related + -- via 'mrProveRel' - i.e. if there exists a type @tp_i@ and 'InjConversion's + -- @c1_i@ and @c2_i@ such that @c1_i@ is an injective conversion from + -- 'tp_i' to 'arg_tp_0', @c2_i@ is an injective conversion from + -- 'tp_i' to 'arg_tp_i', and @arg_tm_0@ and @arg_tm_i@ are convertible when + -- the inverses of @c1_i@ and @c2_i@ are applied. In other words, @eq_specs@ + -- contains all the specs which are equal to @arg_spec_0@ up to some + -- injective conversions. + (eq_specs, uneq_specs) <- fmap partitionEithers $ forM arg_specs $ \arg_spec_i -> + let arg_tm_i = coIndHypArg hyp arg_spec_i in + mrTypeOf arg_tm_i >>= \arg_tp_i -> + findInjConvs arg_tp_0 (Just arg_tm_0) arg_tp_i (Just arg_tm_i) >>= \case + Just cvs -> mrProveRel True arg_tm_0 arg_tm_i >>= \case + True -> return $ Left (arg_spec_i, cvs) + _ -> return $ Right arg_spec_i + _ -> return $ Right arg_spec_i + -- What want to do is generalize all the arg_specs in @eq_specs@ into a + -- single variable (with some appropriate conversions applied). So, what + -- we need to do is find a @tp@ (and appropriate conversions) such that the + -- following diagram holds for all @i@ and @j@ (using the names from the + -- previous comment): + -- + -- > arg_tp_i arg_tp_0 arg_tp_j + -- > ^ ^ ^ ^ + -- > \ / \ / + -- > tp_i tp_j + -- > ^ ^ + -- > \ / + -- > tp + -- + -- To do this, we simply need to call 'findInjConvs' iteratively as we fold + -- through @eq_specs@, and compose the injective conversions appropriately. + -- Each step of this iteration is @cbnConvs@, which can be pictured as: + -- + -- > arg_tp_0 arg_tp_i + -- > ^ ^ ^ + -- > c_0 | c1_i \ / c2_i + -- > | \ / + -- > tp tp_i + -- > ^ ^ + -- > c1 \ / c2 + -- > \ / + -- > tp' + -- + -- where @c1@, @c2@, and @tp'@ come from 'findInjConvs' on @tp@ and @tp_i@, + -- and the @tp@ and @c_0@ to use for the next (@i+1@th) iteration are @tp'@ + -- and @c_0 <> c1@. + let cbnConvs :: (Term, InjConversion, [(a, InjConversion)]) -> + (a, (Term, InjConversion, InjConversion)) -> + MRM (Term, InjConversion, [(a, InjConversion)]) + cbnConvs (tp, c_0, cs) (arg_spec_i, (tp_i, _, c2_i)) = + findInjConvs tp Nothing tp_i Nothing >>= \case + Just (tp', c1, c2) -> + let cs' = fmap (\(spec_j, c_j) -> (spec_j, c_j <> c1)) cs in + return $ (tp', c_0 <> c1, (arg_spec_i, c2_i <> c2) : cs') + Nothing -> error "generalizeCoIndHyp: could not find mutual conversion" + (tp, c_0, eq_specs_cs) <- foldlM cbnConvs (arg_tp_0, NoConv, []) eq_specs + -- Finally we generalize: We add a new variable of type @tp@ and substitute + -- it for all of the arguments in @hyp@ given by @eq_specs@, applying the + -- appropriate conversions from @eq_specs_cs@ + (hyp', var) <- coIndHypWithVar hyp "z" (Type tp) + hyp'' <- foldlM (\hyp_i (arg_spec_i, c_i) -> + coIndHypSetArg hyp_i arg_spec_i <$> mrApplyConv c_i var) + hyp' ((arg_spec_0, c_0) : eq_specs_cs) + -- We finish by recursing on any remaining arg_specs + generalizeCoIndHyp hyp'' uneq_specs ---------------------------------------------------------------------- @@ -903,7 +874,7 @@ mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = mrFunOutType f1 args1 >>= \tp1 -> mrFunOutType f2 args2 >>= \tp2 -> - typesHetRelated tp1 tp2 >>= \tps_rel -> + findInjConvs tp1 Nothing tp2 Nothing >>= \mb_convs -> mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> mrGetCoIndHyp f1 f2 >>= \maybe_coIndHyp -> @@ -923,7 +894,13 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- If we have an opaque FunAssump that f1 args1' refines f2 args2', then -- prove that args1 = args1', args2 = args2', and then that k1 refines k2 (_, Just (FunAssump ctx args1' (OpaqueFunAssump f2' args2'))) | f2 == f2' -> - do evars <- mrFreshEVars ctx + do debugPretty 2 $ flip runPPInCtxM ctx $ + prettyAppList [return "mrRefines using opaque FunAssump:", + prettyInCtx ctx, return ".", + prettyTermApp (funNameTerm f1) args1', + return "|=", + prettyTermApp (funNameTerm f2) args2'] + evars <- mrFreshEVars ctx (args1'', args2'') <- substTermLike 0 evars (args1', args2') zipWithM_ mrAssertProveEq args1'' args1 zipWithM_ mrAssertProveEq args2'' args2 @@ -940,7 +917,17 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- case above, treat either case like we have a rewrite FunAssump and prove -- that args1 = args1' and then that f args refines m2 (_, Just (FunAssump ctx args1' rhs)) -> - do rhs' <- mrFunAssumpRHSAsNormComp rhs + do debugPretty 2 $ flip runPPInCtxM ctx $ + prettyAppList [return "mrRefines rewriting by FunAssump:", + prettyInCtx ctx, return ".", + prettyTermApp (funNameTerm f1) args1', + return "|=", + case rhs of + OpaqueFunAssump f2' args2' -> + prettyTermApp (funNameTerm f2') args2' + RewriteFunAssump rhs_tm -> + prettyInCtx rhs_tm] + rhs' <- mrFunAssumpRHSAsNormComp rhs evars <- mrFreshEVars ctx (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 @@ -960,7 +947,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- recursive and have return types which are heterogeneously related, then -- try to coinductively prove that f1 args1 |= f2 args2 under the assumption -- that f1 args1 |= f2 args2, and then try to prove that k1 |= k2 - _ | tps_rel + _ | Just _ <- mb_convs , Just _ <- maybe_f1_body , Just _ <- maybe_f2_body -> mrRefinesCoInd f1 args1 f2 args2 >> mrRefinesFun tp1 k1 tp2 k2 @@ -1055,16 +1042,16 @@ mrRefines'' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) mrRefinesFun :: Term -> CompFun -> Term -> CompFun -> MRM () mrRefinesFun tp1 f1 tp2 f2 = do mrDebugPPPrefixSep 1 "mrRefinesFun on types:" tp1 "," tp2 - mrDebugPPPrefixSep 1 "mrRefinesFun" f1 "|=" f2 f1' <- compFunToTerm f1 >>= liftSC1 scWhnf f2' <- compFunToTerm f2 >>= liftSC1 scWhnf + mrDebugPPPrefixSep 1 "mrRefinesFun" f1' "|=" f2' let nm1 = maybe "call_ret_val" id (compFunVarName f1) nm2 = maybe "call_ret_val" id (compFunVarName f2) f1'' <- mrLambdaLift [(nm1, tp1)] f1' $ \[var] -> flip mrApply var f2'' <- mrLambdaLift [(nm2, tp2)] f2' $ \[var] -> flip mrApply var - tps1 <- mrTypeOf f1'' - tps2 <- mrTypeOf f2'' - mrRefinesFunH mrRefines [] tps1 f1'' tps2 f2'' + piTp1 <- mrTypeOf f1'' + piTp2 <- mrTypeOf f2'' + mrRefinesFunH mrRefines [] piTp1 f1'' piTp2 f2'' -- | The main loop of 'mrRefinesFun' and 'askMRSolver': given a continuation, -- two terms of function type, and two equal-length lists representing the @@ -1078,127 +1065,58 @@ mrRefinesFunH :: (Term -> Term -> MRM a) -> [Term] -> Term -> Term -> Term -> Term -> MRM a -- Introduce equalities on either side as assumptions -mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t1 tps2 t2 = - mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> - mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> - mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> +mrRefinesFunH k vars (asPi -> Just (nm1, tp1@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t1 piTp2 t2 = liftSC2 scBoolEq b1 b2 >>= \eq -> withAssumption eq $ let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1] ++ catMaybes [ asLambdaName t1 ] in - withUVarLift nm (Type tp1) (vars, t1, tps2, t2) $ \var (vars', t1', tps2', t2') -> + withUVarLift nm (Type tp1) (vars,t1,piTp2,t2) $ \var (vars',t1',piTp2',t2') -> do t1'' <- mrApplyAll t1' [var] - tps1' <- mrTypeOf t1'' - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2' -mrRefinesFunH k vars tps1 t1 tps2@(asPi -> Just (nm2, tp2@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t2 = - mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> - mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> - mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> + piTp1' <- mrTypeOf t1'' + mrRefinesFunH k (var : vars') piTp1' t1'' piTp2' t2' +mrRefinesFunH k vars piTp1 t1 (asPi -> Just (nm2, tp2@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t2 = liftSC2 scBoolEq b1 b2 >>= \eq -> withAssumption eq $ let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm2] ++ catMaybes [ asLambdaName t2 ] in - withUVarLift nm (Type tp2) (vars, tps1, t1, t2) $ \var (vars', tps1', t1', t2') -> + withUVarLift nm (Type tp2) (vars,piTp1,t1,t2) $ \var (vars',piTp1',t1',t2') -> do t2'' <- mrApplyAll t2' [var] - tps2' <- mrTypeOf t2'' - mrRefinesFunH k (var : vars') tps1' t1' tps2' t2'' - -mrRefinesFunH k vars tps1@(asPi -> Just (nm1, tp1, _)) t1 - tps2@(asPi -> Just (nm2, tp2, _)) t2 = - mrUVars >>= mrDebugPPPrefix 3 "mrRefinesFunH uvars:" >> - mrDebugPPPrefixSep 3 "mrRefinesFunH types" tps1 "|=" tps2 >> - mrDebugPPPrefixSep 3 "mrRefinesFunH" t1 "|=" t2 >> - case matchHet tp1 tp2 of - - -- If we need to introduce a bitvector on one side and a Num on the other, - -- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that - -- variable on the Num side - Just (HetBVNum n) -> - let nm = maybe "_" id $ find ((/=) '_' . Text.head) - $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 - , asLambdaName t2 ] in - withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> - do nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [num_tm] - tps1' <- mrTypeOf t1'' - tps2' <- mrTypeOf t2'' >>= liftSC1 scWhnf - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' - Just (HetNumBV n) -> + piTp2' <- mrTypeOf t2'' + mrRefinesFunH k (var : vars') piTp1' t1' piTp2' t2'' + +-- We always curry pair values before introducing them (NOTE: we do this even +-- when the have the same types to ensure we never have to unify a projection +-- of an evar with a non-projected value, e.g. evar.1 == val) +mrRefinesFunH k vars (asPi -> Just (nm1, asPairType -> Just (tpL1, tpR1), _)) t1 + (asPi -> Just (nm2, asPairType -> Just (tpL2, tpR2), _)) t2 = + do t1'' <- mrLambdaLift [(nm1, tpL1), (nm1, tpR1)] t1 $ \[prj1, prj2] t1' -> + liftSC2 scPairValue prj1 prj2 >>= mrApply t1' + t2'' <- mrLambdaLift [(nm2, tpL2), (nm2, tpR2)] t2 $ \[prj1, prj2] t2' -> + liftSC2 scPairValue prj1 prj2 >>= mrApply t2' + piTp1' <- mrTypeOf t1'' + piTp2' <- mrTypeOf t2'' + mrRefinesFunH k vars piTp1' t1'' piTp2' t2'' + +mrRefinesFunH k vars (asPi -> Just (nm1, tp1, _)) t1 + (asPi -> Just (nm2, tp2, _)) t2 = + findInjConvs tp1 Nothing tp2 Nothing >>= \case + -- If we can find injective conversions from from a type @tp@ to @tp1@ and + -- @tp2@, introduce a variable of type @tp@, apply both conversions to it, + -- and substitute the results on the left and right sides, respectively + Just (tp, c1, c2) -> let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 , asLambdaName t2 ] in - withUVarLift nm (Type tp2) (vars, t1, t2) $ \var (vars', t1', t2') -> - do nat_tm <- liftSC2 scBvToNat n var - num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] - t1'' <- mrApplyAll t1' [num_tm] - t2'' <- mrApplyAll t2' [var] - tps1' <- mrTypeOf t1'' >>= liftSC1 scWhnf - tps2' <- mrTypeOf t2'' - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' - - -- If we need to introduce a BVVec on one side and a non-BVVec vector on the - -- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that - -- variable on the non-BVVec side - -- FIXME: Could we handle the a /= a' case here and in generalizeCoIndHypArgs? - Just (HetBVVecVec (n, len, a) (m, a2)) -> - do lens_are_eq <- mrProveEq m =<< mrBvToNat n len - as_are_eq <- mrConvertible a a2 - if lens_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - let nm = maybe "_" id $ find ((/=) '_' . Text.head) - $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 - , asLambdaName t2 ] - withUVarLift nm (Type tp1) (vars, n, len, a, m, t1, t2) $ \var (vars', n', len', a', m', t1', t2') -> - do bvv_tm <- mrGenFromBVVec n' len' a' var "mrRefinesFunH (BVVec/Vec)" m' - t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [bvv_tm] - tps1' <- mrTypeOf t1'' - tps2' <- mrTypeOf t2'' >>= liftSC1 scWhnf - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' - Just (HetVecBVVec (m, a2) (n, len, a)) -> - do lens_are_eq <- mrProveEq m =<< mrBvToNat n len - as_are_eq <- mrConvertible a a2 - if lens_are_eq && as_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - let nm = maybe "_" id $ find ((/=) '_' . Text.head) - $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 - , asLambdaName t2 ] - withUVarLift nm (Type tp2) (vars, n, len, a, m, t1, t2) $ \var (vars', n', len', a', m', t1', t2') -> - do bvv_tm <- mrGenFromBVVec n' len' a' var "mrRefinesFunH (BVVec/Vec)" m' - t1'' <- mrApplyAll t1' [bvv_tm] - t2'' <- mrApplyAll t2' [var] - tps1' <- mrTypeOf t1'' >>= liftSC1 scWhnf - tps2' <- mrTypeOf t2'' - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' - - -- We always curry pair values before introducing them (NOTE: we do this even - -- when the have the same types to ensure we never have to unify a projection - -- of an evar with a non-projected value, i.e. evar.1 == val ) - Just (HetPair (tpL1, tpR1) (tpL2, tpR2)) -> - do t1'' <- mrLambdaLift [(nm1, tpL1), (nm1, tpR1)] t1 $ \[prj1, prj2] t1' -> - liftSC2 scPairValue prj1 prj2 >>= mrApply t1' - t2'' <- mrLambdaLift [(nm2, tpL2), (nm2, tpR2)] t2 $ \[prj1, prj2] t2' -> - liftSC2 scPairValue prj1 prj2 >>= mrApply t2' - tps1' <- mrTypeOf t1'' - tps2' <- mrTypeOf t2'' - mrRefinesFunH k vars tps1' t1'' tps2' t2'' - - -- Introduce variables of the same type together - Nothing -> - do tps_are_eq <- mrConvertible tp1 tp2 - if tps_are_eq then return () else - throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) - let nm = maybe "_" id $ find ((/=) '_' . Text.head) - $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 - , asLambdaName t2 ] - withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> - do t1'' <- mrApplyAll t1' [var] - t2'' <- mrApplyAll t2' [var] - tps1' <- mrTypeOf t1'' - tps2' <- mrTypeOf t2'' - mrRefinesFunH k (var : vars') tps1' t1'' tps2' t2'' + withUVarLift nm (Type tp) (vars,c1,c2,t1,t2) $ \var (vars',c1',c2',t1',t2') -> + do tm1 <- mrApplyConv c1' var + tm2 <- mrApplyConv c2' var + t1'' <- mrApplyAll t1' [tm1] + t2'' <- mrApplyAll t2' [tm2] + piTp1' <- mrTypeOf t1'' >>= liftSC1 scWhnf + piTp2' <- mrTypeOf t2'' >>= liftSC1 scWhnf + mrRefinesFunH k (var : vars') piTp1' t1'' piTp2' t2'' + -- Otherwise, error + Nothing -> throwMRFailure (TypesNotRel True (Type tp1) (Type tp2)) -- Error if we don't have the same number of arguments on both sides -- FIXME: Add a specific error for this case diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 4363ba5ca0..adea7de1bf 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -35,6 +35,7 @@ import Data.String import Data.IORef import Control.Monad.Reader import qualified Data.IntMap as IntMap +import Numeric.Natural (Natural) import GHC.Generics import Prettyprinter @@ -449,6 +450,9 @@ instance TermLike FunName where instance TermLike LocalName where liftTermLike _ _ = return substTermLike _ _ = return +instance TermLike Natural where + liftTermLike _ _ = return + substTermLike _ _ = return deriving anyclass instance TermLike Type deriving instance TermLike NormComp From 37565f96c89bd6c8ee1f4318b94d686467ba9f13 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Tue, 20 Sep 2022 17:58:14 -0700 Subject: [PATCH 12/12] fix bugs in new InjConversion interface --- heapster-saw/examples/sha512_mr_solver.saw | 5 +- src/SAWScript/Prover/MRSolver/SMT.hs | 89 ++++++++++++++-------- src/SAWScript/Prover/MRSolver/Solver.hs | 5 +- 3 files changed, 64 insertions(+), 35 deletions(-) diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 107a825594..bd7ea87192 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -27,8 +27,7 @@ monadify_term {{ processBlock_spec }}; monadify_term {{ processBlocks_loop_spec }}; monadify_term {{ processBlocks_spec }}; -// mr_solver_set_debug_level 3; mr_solver_prove round_00_15 {{ round_00_15_spec }}; -// mr_solver_prove round_16_80 {{ round_16_80_spec }}; -// mr_solver_prove processBlock {{ processBlock_spec }}; +mr_solver_prove round_16_80 {{ round_16_80_spec }}; +mr_solver_prove processBlock {{ processBlock_spec }}; // mr_solver_prove processBlocks {{ processBlocks_spec }}; diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index d70c053e9f..254077bad1 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -405,6 +405,21 @@ mrProvable bool_tm = instUVar :: LocalName -> Term -> MRM Term instUVar nm tp = mrDebugPPPrefix 3 "instUVar" (nm, tp) >> liftSC1 scWhnf tp >>= \case + (asNonBVVecVectorType -> Just (m, a)) -> + liftSC1 smtNorm m >>= \m' -> case asBvToNat m' of + -- For variables of type Vec of length which normalizes to + -- a bvToNat term, recurse and wrap the result in genFromBVVec + Just (n, len) -> do + tp' <- liftSC2 scVecType m' a + tm' <- instUVar nm tp' + mrGenFromBVVec n len a tm' "instUVar" m + -- Otherwise for variables of type Vec, create a @Nat -> a@ + -- function as an ExtCns and apply genBVVec to it + Nothing -> do + nat_tp <- liftSC0 scNatType + tp' <- liftSC3 scPi "_" nat_tp =<< liftTermLike 0 1 a + tm' <- instUVar nm tp' + liftSC2 scGlobalApply "CryptolM.genCryM" [m, a, tm'] -- For variables of type BVVec, create a @Vec n Bool -> a@ function -- as an ExtCns and apply genBVVec to it (asBVVecType -> Just (n, len, a)) -> do @@ -499,23 +514,26 @@ nonTrivialConv (ConvComp cs) = not (null cs) -- and the arguments to those constructors are convertible via 'mrConvertible' mrConvsConvertible :: InjConversion -> InjConversion -> MRM Bool mrConvsConvertible (ConvComp cs1) (ConvComp cs2) = - and <$> zipWithM mrSingleConvsConvertible cs1 cs2 - where mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool - mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True - mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2 - mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1) - (SingleBVVecToVec n2 len2 a2 m2) = - do ns_are_eq <- mrConvertible n1 n2 - lens_are_eq <- mrConvertible len1 len2 - as_are_eq <- mrConvertible a1 a2 - ms_are_eq <- mrConvertible m1 m2 - return $ ns_are_eq && lens_are_eq && as_are_eq && ms_are_eq - mrSingleConvsConvertible (SinglePairToPair cL1 cR1) - (SinglePairToPair cL2 cR2) = - do cLs_are_eq <- mrConvsConvertible cL1 cL2 - cRs_are_eq <- mrConvsConvertible cR1 cR2 - return $ cLs_are_eq && cRs_are_eq - mrSingleConvsConvertible _ _ = return False + if length cs1 /= length cs2 then return False + else and <$> zipWithM mrSingleConvsConvertible cs1 cs2 + +-- | Used in the definition of 'mrConvsConvertible' +mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool +mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True +mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2 +mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1) + (SingleBVVecToVec n2 len2 a2 m2) = + do ns_are_eq <- mrConvertible n1 n2 + lens_are_eq <- mrConvertible len1 len2 + as_are_eq <- mrConvertible a1 a2 + ms_are_eq <- mrConvertible m1 m2 + return $ ns_are_eq && lens_are_eq && as_are_eq && ms_are_eq +mrSingleConvsConvertible (SinglePairToPair cL1 cR1) + (SinglePairToPair cL2 cR2) = + do cLs_are_eq <- mrConvsConvertible cL1 cL2 + cRs_are_eq <- mrConvsConvertible cR1 cR2 + return $ cLs_are_eq && cRs_are_eq +mrSingleConvsConvertible _ _ = return False -- | Apply the given 'InjConversion' to the given term, where compositions -- @c1 <> c2 <> ... <> cn@ are applied from right to left as in function @@ -540,11 +558,16 @@ mrApplyInvConv (ConvComp cs) = flip (foldlM go) cs go t SingleNatToNum = case asNum t of Just (Left t') -> return t' _ -> error "mrApplyInvConv: Num term does not normalize to TCNum constructor" - go t (SingleBVToNat n) = - do n_tm <- liftSC1 scNat n - liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t] - go t (SingleBVVecToVec n len a m) = - mrGenBVVecFromVec m a t "mrApplyInvConv" n len + go t (SingleBVToNat n) = case asBvToNat t of + Just (asNat -> Just n', t') | n == n' -> return t' + _ -> do n_tm <- liftSC1 scNat n + liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t] + go t c@(SingleBVVecToVec n len a m) = case asGenFromBVVecTerm t of + Just (n', len', a', t', _, m') -> + do eq <- mrSingleConvsConvertible c (SingleBVVecToVec n' len' a' m') + if eq then return t' + else mrGenBVVecFromVec m a t "mrApplyInvConv" n len + _ -> mrGenBVVecFromVec m a t "mrApplyInvConv" n len go t (SinglePairToPair c1 c2) = do t1 <- mrApplyInvConv c1 =<< doTermProj t TermProjLeft t2 <- mrApplyInvConv c2 =<< doTermProj t TermProjRight @@ -635,16 +658,18 @@ findInjConvs tp1 t1 (asNonBVVecVectorType -> Just (m, _)) -- bit-width from the other side findInjConvs (asNonBVVecVectorType -> Just (m, a')) _ (asBVVecType -> Just (n, len, a)) _ = - do bvvec_tp <- liftSC2 scVecType n a - lens_are_eq <- mrProveEq m =<< mrBvToNat n len + do len_nat <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + bvvec_tp <- liftSC2 scVecType len_nat a + lens_are_eq <- mrProveEq m len_nat as_are_eq <- mrConvertible a a' if lens_are_eq && as_are_eq then return $ Just (bvvec_tp, BVVecToVec n len a m, NoConv) else return $ Nothing findInjConvs (asBVVecType -> Just (n, len, a)) _ (asNonBVVecVectorType -> Just (m, a')) _ = - do bvvec_tp <- liftSC2 scVecType n a - lens_are_eq <- mrProveEq m =<< mrBvToNat n len + do len_nat <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + bvvec_tp <- liftSC2 scVecType len_nat a + lens_are_eq <- mrProveEq m len_nat as_are_eq <- mrConvertible a a' if lens_are_eq && as_are_eq then return $ Just (bvvec_tp, NoConv, BVVecToVec n len a m) @@ -761,10 +786,12 @@ mrProveRel het t1 t2 = mrDebugPPPrefixSep 2 nm t1 (if het then "~=" else "==") t2 tp1 <- mrTypeOf t1 >>= mrSubstEVars tp2 <- mrTypeOf t2 >>= mrSubstEVars - cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2 - res <- withTermInCtx cond_in_ctx mrProvable - debugPrint 2 $ nm ++ ": " ++ if res then "Success" else "Failure" - return res + tps_eq <- mrConvertible tp1 tp2 + if not het && not tps_eq then return False + else do cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2 + res <- withTermInCtx cond_in_ctx mrProvable + debugPrint 2 $ nm ++ ": " ++ if res then "Success" else "Failure" + return res -- | Prove that two terms are related, heterogeneously iff the first argument, -- is true, instantiating evars if necessary, or throwing an error if this is @@ -883,7 +910,7 @@ mrProveRelH' _ het tp1 tp2 t1 t2 = findInjConvs tp1 (Just t1) tp2 (Just t2) >>= -- injective conversions from a type @tp@ to @tp1@ and @tp2@, apply the -- inverses of these conversions to @t1@ and @t2@ and continue checking -- equality on the results - Just (tp, c1, c2) | het, nonTrivialConv c1 || nonTrivialConv c2 -> do + Just (tp, c1, c2) | nonTrivialConv c1 || nonTrivialConv c2 -> do t1' <- mrApplyInvConv c1 t1 t2' <- mrApplyInvConv c2 t2 mrProveRelH True tp tp t1' t2' diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 56869fa8aa..3ed6866d1c 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -1104,6 +1104,8 @@ mrRefinesFunH k vars (asPi -> Just (nm1, tp1, _)) t1 -- @tp2@, introduce a variable of type @tp@, apply both conversions to it, -- and substitute the results on the left and right sides, respectively Just (tp, c1, c2) -> + mrDebugPPPrefixSep 3 "mrRefinesFunH calling findInjConvs" tp1 "," tp2 >> + mrDebugPPPrefix 3 "mrRefinesFunH got type" tp >> let nm = maybe "_" id $ find ((/=) '_' . Text.head) $ [nm1, nm2] ++ catMaybes [ asLambdaName t1 , asLambdaName t2 ] in @@ -1151,7 +1153,8 @@ type MRSolverResult = Maybe (FunName, FunAssump) askMRSolverH :: (NormComp -> NormComp -> MRM ()) -> Term -> Term -> MRM MRSolverResult askMRSolverH f t1 t2 = - do m1 <- normCompTerm t1 + do mrUVars >>= mrDebugPPPrefix 1 "askMRSolverH uvars:" + m1 <- normCompTerm t1 m2 <- normCompTerm t2 f m1 m2 case (m1, m2) of