From 39859fd19badc7686b3d265bc166d4f5a463d33c Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Tue, 18 Apr 2023 13:55:57 +0200 Subject: [PATCH 1/6] Simplify secure forgetting Changes: - Remove MonadSodium typeclass - Express mlocked memory operations in terms of MonadST - Get rid of ForgetMock testing (has been moot since we no longer depend on GC to reclaim mlocked memory) - Remove `m` parameter from mlocked memory based typeclasses (KES, DSIGN) - Simplify KES and DSIGN typeclass hierarchies --- .../cardano-crypto-class.cabal | 7 +- .../src/Cardano/Crypto/DSIGN/Ed25519ML.hs | 72 +-- .../src/Cardano/Crypto/DSIGNM/Class.hs | 86 +++- .../src/Cardano/Crypto/{MEqOrd.hs => EqST.hs} | 39 +- .../src/Cardano/Crypto/KES/Class.hs | 135 ++++-- .../src/Cardano/Crypto/KES/CompactSingle.hs | 18 +- .../src/Cardano/Crypto/KES/CompactSum.hs | 75 ++- .../src/Cardano/Crypto/KES/Mock.hs | 20 +- .../src/Cardano/Crypto/KES/NeverUsed.hs | 14 +- .../src/Cardano/Crypto/KES/Simple.hs | 34 +- .../src/Cardano/Crypto/KES/Single.hs | 20 +- .../src/Cardano/Crypto/KES/Sum.hs | 72 ++- .../src/Cardano/Crypto/Libsodium.hs | 40 +- .../src/Cardano/Crypto/Libsodium/Hash.hs | 21 +- .../Cardano/Crypto/Libsodium/MLockedBytes.hs | 6 + .../Crypto/Libsodium/MLockedBytes/Internal.hs | 92 ++-- .../Cardano/Crypto/Libsodium/MLockedSeed.hs | 84 ++++ .../src/Cardano/Crypto/Libsodium/Memory.hs | 25 + .../Crypto/Libsodium/Memory/Internal.hs | 230 ++++++++-- .../src/Cardano/Crypto/MLockedSeed.hs | 64 --- .../src/Cardano/Crypto/MonadSodium.hs | 64 --- .../src/Cardano/Crypto/MonadSodium/Alloc.hs | 77 ---- .../src/Cardano/Crypto/MonadSodium/Class.hs | 62 --- .../src/Cardano/Crypto/PinnedSizedBytes.hs | 37 +- .../cardano-crypto-tests.cabal | 2 - cardano-crypto-tests/src/Bench/Crypto/KES.hs | 16 +- .../src/Cardano/Crypto/KES/ForgetMock.hs | 169 ------- .../src/Test/Crypto/AllocLog.hs | 105 +---- cardano-crypto-tests/src/Test/Crypto/DSIGN.hs | 56 +-- .../src/Test/Crypto/Instances.hs | 18 +- cardano-crypto-tests/src/Test/Crypto/KES.hs | 428 ++++++------------ cardano-crypto-tests/src/Test/Crypto/Util.hs | 4 +- cardano-mempool/src/Cardano/Memory/Pool.hs | 249 +++++----- .../tests/Test/Cardano/Memory/PoolTests.hs | 16 +- 34 files changed, 1125 insertions(+), 1332 deletions(-) rename cardano-crypto-class/src/Cardano/Crypto/{MEqOrd.hs => EqST.hs} (54%) create mode 100644 cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs delete mode 100644 cardano-crypto-class/src/Cardano/Crypto/MLockedSeed.hs delete mode 100644 cardano-crypto-class/src/Cardano/Crypto/MonadSodium.hs delete mode 100644 cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Alloc.hs delete mode 100644 cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Class.hs delete mode 100644 cardano-crypto-tests/src/Cardano/Crypto/KES/ForgetMock.hs diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index cf9b91db1..4451c600e 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -77,12 +77,9 @@ library Cardano.Crypto.Libsodium.Memory.Internal Cardano.Crypto.Libsodium.MLockedBytes Cardano.Crypto.Libsodium.MLockedBytes.Internal + Cardano.Crypto.Libsodium.MLockedSeed Cardano.Crypto.Libsodium.UnsafeC - Cardano.Crypto.MEqOrd - Cardano.Crypto.MLockedSeed - Cardano.Crypto.MonadSodium - Cardano.Crypto.MonadSodium.Class - Cardano.Crypto.MonadSodium.Alloc + Cardano.Crypto.EqST Cardano.Crypto.PinnedSizedBytes Cardano.Crypto.Seed Cardano.Crypto.Util diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs index 923264a55..a4cbc2e2f 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs @@ -39,22 +39,29 @@ import Control.Monad.ST.Unsafe (unsafeIOToST) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Foreign -import Cardano.Crypto.PinnedSizedBytes import Cardano.Crypto.Libsodium.C -import Cardano.Crypto.Libsodium (MLockedSizedBytes) -import Cardano.Crypto.MonadSodium - ( MonadSodium (..) +import Cardano.Crypto.Libsodium + ( MLockedSizedBytes , mlsbToByteString - , mlsbFromByteStringCheck + , mlsbFromByteStringCheckWith , mlsbUseAsSizedPtr - , mlsbNew + , mlsbNewWith , mlsbFinalize - , mlsbCopy - , MEq (..) + , mlsbCopyWith + ) +import Cardano.Crypto.PinnedSizedBytes + ( PinnedSizedBytes + , psbUseAsSizedPtr + , psbToByteString + , psbFromByteStringCheck + , psbCreateSizedResult + ) +import Cardano.Crypto.EqST + ( EqST (..) ) import Cardano.Crypto.DSIGNM.Class -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Util (SignableRepresentation(..)) data Ed25519DSIGNM @@ -83,8 +90,8 @@ cOrError action = do else Just <$> unsafeIOToST getErrno --- | Throws an appropriate 'IOException' when 'Just' an 'Errno' is given. -throwOnErrno :: (MonadThrow m) => String -> String -> Maybe Errno -> m () +-- | Throws an error when 'Just' an 'Errno' is given. +throwOnErrno :: MonadThrow m => String -> String -> Maybe Errno -> m () throwOnErrno contextDesc cFunName maybeErrno = do case maybeErrno of Just errno -> throwIO $ errnoToIOError (contextDesc ++ ": " ++ cFunName) errno Nothing Nothing @@ -171,7 +178,7 @@ instance DSIGNMAlgorithmBase Ed25519DSIGNM where -- reflects this. -- -- Various libsodium primitives, particularly 'MLockedSizedBytes' primitives, --- are used via the 'MonadSodium' typeclass, which is responsible for +-- are used via the 'MonadST' typeclass, which is responsible for -- guaranteeing orderly execution of these actions. We avoid using these -- primitives inside 'unsafeIOToST', as well as any 'IO' actions that would be -- unsafe to use inside 'unsafePerformIO'. @@ -186,14 +193,13 @@ instance DSIGNMAlgorithmBase Ed25519DSIGNM where -- memory passed to them via C pointers. -- - 'getErrno'; however, 'ST' guarantees sequentiality in the context where -- we use 'getErrno', so this is fine. --- - 'BS.useAsCStringLen', which is fine and shouldn't require 'IO' to begin --- with, but unfortunately, for historical reasons, does. -instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DSIGNM where +instance DSIGNMAlgorithm Ed25519DSIGNM where deriveVerKeyDSIGNM (SignKeyEd25519DSIGNM sk) = VerKeyEd25519DSIGNM <$!> do mlsbUseAsSizedPtr sk $ \skPtr -> do - (psb, maybeErrno) <- withLiftST $ \fromST -> fromST $ do - psbCreateSizedResult $ \pkPtr -> + (psb, maybeErrno) <- + psbCreateSizedResult $ \pkPtr -> + withLiftST $ \fromST -> fromST $ do cOrError $ unsafeIOToST $ c_crypto_sign_ed25519_sk_to_pk pkPtr skPtr throwOnErrno "deriveVerKeyDSIGNM @Ed25519DSIGNM" "c_crypto_sign_ed25519_sk_to_pk" maybeErrno @@ -204,8 +210,9 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS let bs = getSignableRepresentation a in SigEd25519DSIGNM <$!> do mlsbUseAsSizedPtr sk $ \skPtr -> do - (psb, maybeErrno) <- withLiftST $ \fromST -> fromST $ do - psbCreateSizedResult $ \sigPtr -> do + (psb, maybeErrno) <- + psbCreateSizedResult $ \sigPtr -> do + withLiftST $ \fromST -> fromST $ do cOrError $ unsafeIOToST $ do BS.useAsCStringLen bs $ \(ptr, len) -> c_crypto_sign_ed25519_detached sigPtr nullPtr (castPtr ptr) (fromIntegral len) skPtr @@ -215,9 +222,9 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS -- -- Key generation -- - {-# NOINLINE genKeyDSIGNM #-} - genKeyDSIGNM seed = SignKeyEd25519DSIGNM <$!> do - sk <- mlsbNew + {-# NOINLINE genKeyDSIGNMWith #-} + genKeyDSIGNMWith allocator seed = SignKeyEd25519DSIGNM <$!> do + sk <- mlsbNewWith allocator mlsbUseAsSizedPtr sk $ \skPtr -> mlockedSeedUseAsCPtr seed $ \seedPtr -> do maybeErrno <- withLiftST $ \fromST -> @@ -230,11 +237,11 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS allocaSizedST k = unsafeIOToST $ allocaSized $ \ptr -> stToIO $ k ptr - cloneKeyDSIGNM (SignKeyEd25519DSIGNM sk) = - SignKeyEd25519DSIGNM <$!> mlsbCopy sk + cloneKeyDSIGNMWith allocator (SignKeyEd25519DSIGNM sk) = + SignKeyEd25519DSIGNM <$!> mlsbCopyWith allocator sk - getSeedDSIGNM _ (SignKeyEd25519DSIGNM sk) = do - seed <- mlockedSeedNew + getSeedDSIGNMWith allocator _ (SignKeyEd25519DSIGNM sk) = do + seed <- mlockedSeedNewWith allocator mlsbUseAsSizedPtr sk $ \skPtr -> mlockedSeedUseAsSizedPtr seed $ \seedPtr -> do maybeErrno <- withLiftST $ \fromST -> @@ -247,13 +254,12 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => DSIGNMAlgorithm m Ed25519DS -- -- Secure forgetting -- - forgetSignKeyDSIGNM (SignKeyEd25519DSIGNM sk) = do - mlsbFinalize sk + forgetSignKeyDSIGNMWith _ (SignKeyEd25519DSIGNM sk) = mlsbFinalize sk deriving via (MLockedSizedBytes (SizeSignKeyDSIGNM Ed25519DSIGNM)) - instance (MonadST m, MonadSodium m) => MEq m (SignKeyDSIGNM Ed25519DSIGNM) + instance EqST (SignKeyDSIGNM Ed25519DSIGNM) -instance (MonadST m, MonadSodium m, MonadThrow m) => UnsoundDSIGNMAlgorithm m Ed25519DSIGNM where +instance UnsoundDSIGNMAlgorithm Ed25519DSIGNM where -- -- Ser/deser (dangerous - do not use in production code) -- @@ -266,12 +272,12 @@ instance (MonadST m, MonadSodium m, MonadThrow m) => UnsoundDSIGNMAlgorithm m Ed mlockedSeedFinalize seed return raw - rawDeserialiseSignKeyDSIGNM raw = do - mseed <- fmap MLockedSeed <$> mlsbFromByteStringCheck raw + rawDeserialiseSignKeyDSIGNMWith allocator raw = do + mseed <- fmap MLockedSeed <$> mlsbFromByteStringCheckWith allocator raw case mseed of Nothing -> return Nothing Just seed -> do - sk <- Just <$> genKeyDSIGNM seed + sk <- Just <$> genKeyDSIGNMWith allocator seed mlockedSeedFinalize seed return sk diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGNM/Class.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGNM/Class.hs index d53020645..68c1fe580 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGNM/Class.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGNM/Class.hs @@ -2,7 +2,6 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -11,6 +10,7 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} -- | Abstract digital signatures. module Cardano.Crypto.DSIGNM.Class @@ -23,6 +23,10 @@ module Cardano.Crypto.DSIGNM.Class , sizeVerKeyDSIGNM , sizeSignKeyDSIGNM , sizeSigDSIGNM + , genKeyDSIGNM + , cloneKeyDSIGNM + , getSeedDSIGNM + , forgetSignKeyDSIGNM -- * 'SignedDSIGNM' wrapper , SignedDSIGNM (..) @@ -46,6 +50,7 @@ module Cardano.Crypto.DSIGNM.Class , UnsoundDSIGNMAlgorithm (..) , encodeSignKeyDSIGNM , decodeSignKeyDSIGNM + , rawDeserialiseSignKeyDSIGNM ) where @@ -56,14 +61,17 @@ import Data.Proxy (Proxy(..)) import Data.Typeable (Typeable) import GHC.Exts (Constraint) import GHC.Generics (Generic) -import GHC.Stack +import GHC.Stack (HasCallStack) import GHC.TypeLits (KnownNat, Nat, natVal, TypeError, ErrorMessage (..)) import NoThunks.Class (NoThunks) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) import Cardano.Binary (Decoder, decodeBytes, Encoding, encodeBytes, Size, withWordSize) import Cardano.Crypto.Util (Empty) -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium (MLockedAllocator, mlockedMalloc) import Cardano.Crypto.Hash.Class (HashAlgorithm, Hash, hashWith) class ( Typeable v @@ -135,23 +143,20 @@ class ( Typeable v rawDeserialiseVerKeyDSIGNM :: ByteString -> Maybe (VerKeyDSIGNM v) rawDeserialiseSigDSIGNM :: ByteString -> Maybe (SigDSIGNM v) -class ( DSIGNMAlgorithmBase v - , Monad m - ) - => DSIGNMAlgorithm m v where +class DSIGNMAlgorithmBase v => DSIGNMAlgorithm v where -- -- Metadata and basic key operations -- - deriveVerKeyDSIGNM :: SignKeyDSIGNM v -> m (VerKeyDSIGNM v) + deriveVerKeyDSIGNM :: (MonadThrow m, MonadST m) => SignKeyDSIGNM v -> m (VerKeyDSIGNM v) -- -- Core algorithm operations -- signDSIGNM - :: (SignableM v a, HasCallStack) + :: (SignableM v a, MonadST m, MonadThrow m) => ContextDSIGNM v -> a -> SignKeyDSIGNM v @@ -161,29 +166,69 @@ class ( DSIGNMAlgorithmBase v -- Key generation -- - genKeyDSIGNM :: MLockedSeed (SeedSizeDSIGNM v) -> m (SignKeyDSIGNM v) + genKeyDSIGNMWith :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> MLockedSeed (SeedSizeDSIGNM v) + -> m (SignKeyDSIGNM v) - cloneKeyDSIGNM :: SignKeyDSIGNM v -> m (SignKeyDSIGNM v) + cloneKeyDSIGNMWith :: MonadST m => MLockedAllocator m -> SignKeyDSIGNM v -> m (SignKeyDSIGNM v) - getSeedDSIGNM :: Proxy v -> SignKeyDSIGNM v -> m (MLockedSeed (SeedSizeDSIGNM v)) + getSeedDSIGNMWith :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> Proxy v + -> SignKeyDSIGNM v + -> m (MLockedSeed (SeedSizeDSIGNM v)) -- -- Secure forgetting -- - forgetSignKeyDSIGNM :: SignKeyDSIGNM v -> m () + forgetSignKeyDSIGNMWith :: (MonadST m, MonadThrow m) => MLockedAllocator m -> SignKeyDSIGNM v -> m () + + +forgetSignKeyDSIGNM :: (DSIGNMAlgorithm v, MonadST m, MonadThrow m) => SignKeyDSIGNM v -> m () +forgetSignKeyDSIGNM = forgetSignKeyDSIGNMWith mlockedMalloc + + +genKeyDSIGNM :: + (DSIGNMAlgorithm v, MonadST m, MonadThrow m) + => MLockedSeed (SeedSizeDSIGNM v) + -> m (SignKeyDSIGNM v) +genKeyDSIGNM = genKeyDSIGNMWith mlockedMalloc + +cloneKeyDSIGNM :: + (DSIGNMAlgorithm v, MonadST m) => SignKeyDSIGNM v -> m (SignKeyDSIGNM v) +cloneKeyDSIGNM = cloneKeyDSIGNMWith mlockedMalloc + +getSeedDSIGNM :: + (DSIGNMAlgorithm v, MonadST m, MonadThrow m) + => Proxy v + -> SignKeyDSIGNM v + -> m (MLockedSeed (SeedSizeDSIGNM v)) +getSeedDSIGNM = getSeedDSIGNMWith mlockedMalloc + -- | Unsound operations on DSIGNM sign keys. These operations violate secure -- forgetting constraints by leaking secrets to unprotected memory. Consider -- using the 'DirectSerialise' / 'DirectDeserialise' APIs instead. -class DSIGNMAlgorithm m v => UnsoundDSIGNMAlgorithm m v where +class DSIGNMAlgorithm v => UnsoundDSIGNMAlgorithm v where -- -- Serialisation/(de)serialisation in fixed-size raw format -- - rawSerialiseSignKeyDSIGNM :: SignKeyDSIGNM v -> m ByteString + rawSerialiseSignKeyDSIGNM :: + (MonadST m, MonadThrow m) => SignKeyDSIGNM v -> m ByteString + + rawDeserialiseSignKeyDSIGNMWith :: + (MonadST m, MonadThrow m) => MLockedAllocator m -> ByteString -> m (Maybe (SignKeyDSIGNM v)) + +rawDeserialiseSignKeyDSIGNM :: + (UnsoundDSIGNMAlgorithm v, MonadST m, MonadThrow m) + => ByteString + -> m (Maybe (SignKeyDSIGNM v)) +rawDeserialiseSignKeyDSIGNM = + rawDeserialiseSignKeyDSIGNMWith mlockedMalloc - rawDeserialiseSignKeyDSIGNM :: ByteString -> m (Maybe (SignKeyDSIGNM v)) -- -- Do not provide Ord instances for keys, see #38 @@ -221,7 +266,10 @@ sizeSigDSIGNM _ = fromInteger (natVal (Proxy @(SizeSigDSIGNM v))) encodeVerKeyDSIGNM :: DSIGNMAlgorithmBase v => VerKeyDSIGNM v -> Encoding encodeVerKeyDSIGNM = encodeBytes . rawSerialiseVerKeyDSIGNM -encodeSignKeyDSIGNM :: (UnsoundDSIGNMAlgorithm m v) => SignKeyDSIGNM v -> m Encoding +encodeSignKeyDSIGNM :: + (UnsoundDSIGNMAlgorithm v, MonadST m, MonadThrow m) + => SignKeyDSIGNM v + -> m Encoding encodeSignKeyDSIGNM = fmap encodeBytes . rawSerialiseSignKeyDSIGNM encodeSigDSIGNM :: DSIGNMAlgorithmBase v => SigDSIGNM v -> Encoding @@ -242,7 +290,7 @@ decodeVerKeyDSIGNM = do actual = BS.length bs decodeSignKeyDSIGNM :: forall m v s - . (UnsoundDSIGNMAlgorithm m v) + . (UnsoundDSIGNMAlgorithm v, MonadST m, MonadThrow m) => Decoder s (m (SignKeyDSIGNM v)) decodeSignKeyDSIGNM = do bs <- decodeBytes @@ -282,7 +330,7 @@ instance DSIGNMAlgorithmBase v => NoThunks (SignedDSIGNM v a) -- use generic instance signedDSIGNM - :: (DSIGNMAlgorithm m v, SignableM v a) + :: (DSIGNMAlgorithm v, SignableM v a, MonadST m, MonadThrow m) => ContextDSIGNM v -> a -> SignKeyDSIGNM v diff --git a/cardano-crypto-class/src/Cardano/Crypto/MEqOrd.hs b/cardano-crypto-class/src/Cardano/Crypto/EqST.hs similarity index 54% rename from cardano-crypto-class/src/Cardano/Crypto/MEqOrd.hs rename to cardano-crypto-class/src/Cardano/Crypto/EqST.hs index f041a6fc5..ffdae7e0b 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/MEqOrd.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/EqST.hs @@ -1,57 +1,58 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -module Cardano.Crypto.MEqOrd -where +module Cardano.Crypto.EqST where + +import Control.Monad.Class.MonadST (MonadST) -- | Monadic flavor of 'Eq', for things that can only be compared in a monadic --- context. +-- context that satisfies 'MonadST'. -- This is needed because we cannot have a sound 'Eq' instance on mlocked --- memory types. -class MEq m a where - equalsM :: a -> a -> m Bool +-- memory types, but we do need to compare them for equality in tests. +class EqST a where + equalsM :: MonadST m => a -> a -> m Bool -nequalsM :: (Functor m, MEq m a) => a -> a -> m Bool +nequalsM :: (MonadST m, EqST a) => a -> a -> m Bool nequalsM a b = not <$> equalsM a b -- | Infix version of 'equalsM' -(==!) :: MEq m a => a -> a -> m Bool +(==!) :: (MonadST m, EqST a) => a -> a -> m Bool (==!) = equalsM infix 4 ==! -- | Infix version of 'nequalsM' -(!=!) :: (Functor m, MEq m a) => a -> a -> m Bool +(!=!) :: (MonadST m, EqST a) => a -> a -> m Bool (!=!) = nequalsM infix 4 !=! -instance (Applicative m, MEq m a) => MEq m (Maybe a) where +instance EqST a => EqST (Maybe a) where equalsM Nothing Nothing = pure True equalsM (Just a) (Just b) = equalsM a b equalsM _ _ = pure False -instance (Applicative m, MEq m a, MEq m b) => MEq m (Either a b) where +instance (EqST a, EqST b) => EqST (Either a b) where equalsM (Left x) (Left y) = equalsM x y equalsM (Right x) (Right y) = equalsM x y equalsM _ _ = pure False -instance (Applicative m, MEq m a, MEq m b) => MEq m (a, b) where +instance (EqST a, EqST b) => EqST (a, b) where equalsM (a, b) (a', b') = (&&) <$> equalsM a a' <*> equalsM b b' -instance (Applicative m, MEq m a, MEq m b, MEq m c) => MEq m (a, b, c) where +instance (EqST a, EqST b, EqST c) => EqST (a, b, c) where equalsM (a, b, c) (a', b', c') = equalsM ((a, b), c) ((a', b'), c') -instance (Applicative m, MEq m a, MEq m b, MEq m c, MEq m d) => MEq m (a, b, c, d) where +instance (EqST a, EqST b, EqST c, EqST d) => EqST (a, b, c, d) where equalsM (a, b, c, d) (a', b', c', d') = equalsM ((a, b, c), d) ((a', b', c'), d') -- TODO: If anyone needs larger tuples, add more instances here... --- | Helper newtype, useful for defining 'MEq' in terms of 'Eq' for types that +-- | Helper newtype, useful for defining 'EqST' in terms of 'Eq' for types that -- have sound 'Eq' instances, using @DerivingVia@. An 'Applicative' context -- must be provided for such instances to work, so this will generally require -- @StandaloneDeriving@ as well. -- --- Ex.: @deriving via PureEq Int instance Applicative m => MEq m Int@ -newtype PureMEq a = PureMEq a +-- Ex.: @deriving via PureEq Int instance Applicative m => EqST m Int@ +newtype PureEqST a = PureEqST a -instance (Applicative m, Eq a) => MEq m (PureMEq a) where - equalsM (PureMEq a) (PureMEq b) = pure (a == b) +instance Eq a => EqST (PureEqST a) where + equalsM (PureEqST a) (PureEqST b) = pure (a == b) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs index 18f721728..6e61adb94 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Class.hs @@ -9,6 +9,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE RankNTypes #-} -- | Abstract key evolving signatures. module Cardano.Crypto.KES.Class @@ -16,6 +17,9 @@ module Cardano.Crypto.KES.Class -- * KES algorithm class KESAlgorithm (..) , KESSignAlgorithm (..) + , genKeyKES + , updateKES + , forgetSignKeyKES , Period , OptimizedKESAlgorithm (..) @@ -49,6 +53,7 @@ module Cardano.Crypto.KES.Class , UnsoundKESSignAlgorithm (..) , encodeSignKeyKES , decodeSignKeyKES + , rawDeserialiseSignKeyKES -- * Utility functions -- These are used between multiple KES implementations. User code will @@ -71,11 +76,14 @@ import GHC.Generics (Generic) import GHC.Stack import GHC.TypeLits (Nat, KnownNat, natVal, TypeError, ErrorMessage (..)) import NoThunks.Class (NoThunks) +import Control.Monad.Class.MonadST (MonadST) +import Control.Monad.Class.MonadThrow (MonadThrow) import Cardano.Binary (Decoder, decodeBytes, Encoding, encodeBytes, Size, withWordSize) import Cardano.Crypto.Util (Empty) -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium (MLockedAllocator, mlockedMalloc) import Cardano.Crypto.Hash.Class (HashAlgorithm, Hash, hashWith) import Cardano.Crypto.DSIGN.Class (failSizeCheck) @@ -172,58 +180,39 @@ seedSizeKES :: forall v proxy. KESAlgorithm v => proxy v -> Word seedSizeKES _ = fromInteger (natVal (Proxy @(SeedSizeKES v))) -class ( KESAlgorithm v - , Monad m - ) - => KESSignAlgorithm m v where +class KESAlgorithm v => KESSignAlgorithm v where data SignKeyKES v :: Type - deriveVerKeyKES :: SignKeyKES v -> m (VerKeyKES v) + deriveVerKeyKES :: (MonadST m, MonadThrow m) => SignKeyKES v -> m (VerKeyKES v) -- -- Core algorithm operations -- signKES - :: forall a. (Signable v a, HasCallStack) + :: forall a m. (Signable v a, MonadST m, MonadThrow m) => ContextKES v -> Period -- ^ The /current/ period for the key -> a -> SignKeyKES v -> m (SigKES v) - -- | Update the KES signature key to the /next/ period, given the /current/ - -- period. - -- - -- It returns 'Nothing' if the cannot be evolved any further. - -- - -- The precondition (to get a 'Just' result) is that the current KES period - -- of the input key is not the last period. The given period must be the - -- current KES period of the input key (not the next or target). - -- - -- The postcondition is that in case a key is returned, its current KES - -- period is incremented by one compared to before. - -- - -- Note that you must track the current period separately, and to skip to a - -- later period requires repeated use of this function, since it only - -- increments one period at once. - -- - updateKES - :: HasCallStack - => ContextKES v + updateKESWith + :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> ContextKES v -> SignKeyKES v -> Period -- ^ The /current/ period for the key, not the target period. -> m (Maybe (SignKeyKES v)) - -- - -- Key generation - -- - - genKeyKES - :: MLockedSeed (SeedSizeKES v) + genKeyKESWith + :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> MLockedSeed (SeedSizeKES v) -> m (SignKeyKES v) + -- -- Secure forgetting -- @@ -234,16 +223,75 @@ class ( KESAlgorithm v -- -- The precondition is that this key value will not be used again. -- - forgetSignKeyKES - :: SignKeyKES v + forgetSignKeyKESWith + :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> SignKeyKES v -> m () +-- | Forget a signing key synchronously, rather than waiting for GC. In some +-- non-mock instances this provides a guarantee that the signing key is no +-- longer in memory. +-- +-- The precondition is that this key value will not be used again. +-- +forgetSignKeyKES + :: (KESSignAlgorithm v, MonadST m, MonadThrow m) + => SignKeyKES v + -> m () +forgetSignKeyKES = forgetSignKeyKESWith mlockedMalloc + +-- | Key generation +-- +genKeyKES + :: forall v m. (KESSignAlgorithm v, MonadST m, MonadThrow m) + => MLockedSeed (SeedSizeKES v) + -> m (SignKeyKES v) +genKeyKES = genKeyKESWith mlockedMalloc + + +-- | Update the KES signature key to the /next/ period, given the /current/ +-- period. +-- +-- It returns 'Nothing' if the cannot be evolved any further. +-- +-- The precondition (to get a 'Just' result) is that the current KES period +-- of the input key is not the last period. The given period must be the +-- current KES period of the input key (not the next or target). +-- +-- The postcondition is that in case a key is returned, its current KES +-- period is incremented by one compared to before. +-- +-- Note that you must track the current period separately, and to skip to a +-- later period requires repeated use of this function, since it only +-- increments one period at once. +-- +updateKES + :: forall v m. (KESSignAlgorithm v, MonadST m, MonadThrow m) + => ContextKES v + -> SignKeyKES v + -> Period -- ^ The /current/ period for the key, not the target period. + -> m (Maybe (SignKeyKES v)) +updateKES = updateKESWith mlockedMalloc + + -- | Unsound operations on KES sign keys. These operations violate secure -- forgetting constraints by leaking secrets to unprotected memory. Consider -- using the 'DirectSerialise' / 'DirectDeserialise' APIs instead. -class (KESSignAlgorithm m v) => UnsoundKESSignAlgorithm m v where - rawDeserialiseSignKeyKES :: ByteString -> m (Maybe (SignKeyKES v)) - rawSerialiseSignKeyKES :: SignKeyKES v -> m ByteString +class KESSignAlgorithm v => UnsoundKESSignAlgorithm v where + rawDeserialiseSignKeyKESWith :: (MonadST m, MonadThrow m) + => MLockedAllocator m + -> ByteString + -> m (Maybe (SignKeyKES v)) + + rawSerialiseSignKeyKES :: (MonadST m, MonadThrow m) => SignKeyKES v -> m ByteString + +rawDeserialiseSignKeyKES :: + (UnsoundKESSignAlgorithm v, MonadST m, MonadThrow m) + => ByteString + -> m (Maybe (SignKeyKES v)) +rawDeserialiseSignKeyKES = rawDeserialiseSignKeyKESWith mlockedMalloc + -- | Subclass for KES algorithms that embed a copy of the VerKey into the -- signature itself, rather than relying on the externally supplied VerKey @@ -315,7 +363,10 @@ encodeVerKeyKES = encodeBytes . rawSerialiseVerKeyKES encodeSigKES :: KESAlgorithm v => SigKES v -> Encoding encodeSigKES = encodeBytes . rawSerialiseSigKES -encodeSignKeyKES :: forall v m. (UnsoundKESSignAlgorithm m v) => SignKeyKES v -> m Encoding +encodeSignKeyKES :: + forall v m. (UnsoundKESSignAlgorithm v, MonadST m, MonadThrow m) + => SignKeyKES v + -> m Encoding encodeSignKeyKES = fmap encodeBytes . rawSerialiseSignKeyKES decodeVerKeyKES :: forall v s. KESAlgorithm v => Decoder s (VerKeyKES v) @@ -334,7 +385,9 @@ decodeSigKES = do Nothing -> failSizeCheck "decodeSigKES" "signature" bs (sizeSigKES (Proxy :: Proxy v)) {-# INLINE decodeSigKES #-} -decodeSignKeyKES :: forall v s m. (UnsoundKESSignAlgorithm m v) => Decoder s (m (Maybe (SignKeyKES v))) +decodeSignKeyKES :: + forall v s m. (UnsoundKESSignAlgorithm v, MonadST m, MonadThrow m) + => Decoder s (m (Maybe (SignKeyKES v))) decodeSignKeyKES = do bs <- decodeBytes let expected = fromIntegral (sizeSignKeyKES (Proxy @v)) @@ -362,13 +415,13 @@ instance KESAlgorithm v => NoThunks (SignedKES v a) -- use generic instance signedKES - :: (KESSignAlgorithm m v, Signable v a) + :: (KESSignAlgorithm v, Signable v a, MonadST m, MonadThrow m) => ContextKES v -> Period -> a -> SignKeyKES v -> m (SignedKES v a) -signedKES ctxt time a key = SignedKES <$> (signKES ctxt time a key) +signedKES ctxt time a key = SignedKES <$> signKES ctxt time a key verifySignedKES :: (KESAlgorithm v, Signable v a) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs index 2a58a480b..07cc9a9a9 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSingle.hs @@ -141,10 +141,10 @@ instance ( DSIGNMAlgorithmBase d off_sig = 0 :: Word off_vk = size_sig -instance ( DSIGNMAlgorithm m d -- needed for secure forgetting +instance ( DSIGNMAlgorithm d -- needed for secure forgetting , KnownNat (SizeSigDSIGNM d + SizeVerKeyDSIGNM d) ) - => KESSignAlgorithm m (CompactSingleKES d) where + => KESSignAlgorithm (CompactSingleKES d) where newtype SignKeyKES (CompactSingleKES d) = SignKeyCompactSingleKES (SignKeyDSIGNM d) deriveVerKeyKES (SignKeyCompactSingleKES v) = @@ -157,19 +157,19 @@ instance ( DSIGNMAlgorithm m d -- needed for secure forgetting assert (t == 0) $ SigCompactSingleKES <$!> signDSIGNM ctxt a sk <*> deriveVerKeyDSIGNM sk - updateKES _ctx (SignKeyCompactSingleKES _sk) _to = return Nothing + updateKESWith _allocator _ctx (SignKeyCompactSingleKES _sk) _to = return Nothing -- -- Key generation -- - genKeyKES seed = SignKeyCompactSingleKES <$!> genKeyDSIGNM seed + genKeyKESWith allocator seed = SignKeyCompactSingleKES <$!> genKeyDSIGNMWith allocator seed -- -- forgetting -- - forgetSignKeyKES (SignKeyCompactSingleKES v) = - forgetSignKeyDSIGNM v + forgetSignKeyKESWith allocator (SignKeyCompactSingleKES v) = + forgetSignKeyDSIGNMWith allocator v instance ( KESAlgorithm (CompactSingleKES d) , DSIGNMAlgorithmBase d @@ -182,10 +182,10 @@ instance ( KESAlgorithm (CompactSingleKES d) assert (t == 0) $ VerKeyCompactSingleKES vk -instance (KESSignAlgorithm m (CompactSingleKES d), UnsoundDSIGNMAlgorithm m d) - => UnsoundKESSignAlgorithm m (CompactSingleKES d) where +instance (KESSignAlgorithm (CompactSingleKES d), UnsoundDSIGNMAlgorithm d) + => UnsoundKESSignAlgorithm (CompactSingleKES d) where rawSerialiseSignKeyKES (SignKeyCompactSingleKES sk) = rawSerialiseSignKeyDSIGNM sk - rawDeserialiseSignKeyKES bs = fmap SignKeyCompactSingleKES <$> rawDeserialiseSignKeyDSIGNM bs + rawDeserialiseSignKeyKESWith allocator bs = fmap SignKeyCompactSingleKES <$> rawDeserialiseSignKeyDSIGNMWith allocator bs -- diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs index bb4075ec4..6f7e247e2 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/CompactSum.hs @@ -1,6 +1,5 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -8,7 +7,6 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -97,10 +95,8 @@ import Cardano.Crypto.Hash.Class import Cardano.Crypto.KES.Class import Cardano.Crypto.KES.CompactSingle (CompactSingleKES) import Cardano.Crypto.Util -import Cardano.Crypto.MLockedSeed -import qualified Cardano.Crypto.MonadSodium as NaCl -import Control.Monad.Class.MonadST (MonadST) -import Control.Monad.Class.MonadThrow (MonadThrow) +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.Monad.Trans (lift) import Control.DeepSeq (NFData (..)) @@ -150,7 +146,7 @@ instance (NFData (SignKeyKES d), NFData (VerKeyKES d)) => rnf (sk, r, vk1, vk2) instance ( OptimizedKESAlgorithm d - , NaCl.SodiumHashAlgorithm h -- needed for secure forgetting + , SodiumHashAlgorithm h -- needed for secure forgetting , SizeHash h ~ SeedSizeKES d -- can be relaxed , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) @@ -249,18 +245,15 @@ instance ( OptimizedKESAlgorithm d off_vk = size_sig instance ( OptimizedKESAlgorithm d - , KESSignAlgorithm m d - , NaCl.SodiumHashAlgorithm h -- needed for secure forgetting + , KESSignAlgorithm d + , SodiumHashAlgorithm h -- needed for secure forgetting , SizeHash h ~ SeedSizeKES d -- can be relaxed - , NaCl.MonadSodium m - , MonadST m -- only needed for unsafe raw ser/deser - , MonadThrow m , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) , KnownNat (SizeSignKeyKES (CompactSumKES h d)) , KnownNat (SizeSigKES (CompactSumKES h d)) ) - => KESSignAlgorithm m (CompactSumKES h d) where + => KESSignAlgorithm (CompactSumKES h d) where -- | From Figure 3: @(sk_0, r_1, vk_0, vk_1)@ -- data SignKeyKES (CompactSumKES h d) = @@ -284,22 +277,22 @@ instance ( OptimizedKESAlgorithm d _T = totalPeriodsKES (Proxy :: Proxy d) - {-# NOINLINE updateKES #-} - updateKES ctx (SignKeyCompactSumKES sk r_1 vk_0 vk_1) t + {-# NOINLINE updateKESWith #-} + updateKESWith allocator ctx (SignKeyCompactSumKES sk r_1 vk_0 vk_1) t | t+1 < _T = runMaybeT $! do - sk' <- MaybeT $! updateKES ctx sk t - r_1' <- lift $! mlockedSeedCopy r_1 + sk' <- MaybeT $! updateKESWith allocator ctx sk t + r_1' <- lift $! mlockedSeedCopyWith allocator r_1 return $! SignKeyCompactSumKES sk' r_1' vk_0 vk_1 | t+1 == _T = do - sk' <- genKeyKES r_1 - zero <- mlockedSeedNewZero + sk' <- genKeyKESWith allocator r_1 + zero <- mlockedSeedNewZeroWith allocator return $! Just $! SignKeyCompactSumKES sk' zero vk_0 vk_1 | otherwise = runMaybeT $! do - sk' <- MaybeT $! updateKES ctx sk (t - _T) - r_1' <- lift $! mlockedSeedCopy r_1 + sk' <- MaybeT $! updateKESWith allocator ctx sk (t - _T) + r_1' <- lift $! mlockedSeedCopyWith allocator r_1 return $! SignKeyCompactSumKES sk' r_1' vk_0 vk_1 where _T = totalPeriodsKES (Proxy :: Proxy d) @@ -308,14 +301,14 @@ instance ( OptimizedKESAlgorithm d -- Key generation -- - {-# NOINLINE genKeyKES #-} - genKeyKES r = do - (r0raw, r1raw) <- NaCl.expandHash (Proxy :: Proxy h) (mlockedSeedMLSB r) + {-# NOINLINE genKeyKESWith #-} + genKeyKESWith allocator r = do + (r0raw, r1raw) <- expandHashWith allocator (Proxy :: Proxy h) (mlockedSeedMLSB r) let r0 = MLockedSeed r0raw r1 = MLockedSeed r1raw - sk_0 <- genKeyKES r0 + sk_0 <- genKeyKESWith allocator r0 vk_0 <- deriveVerKeyKES sk_0 - sk_1 <- genKeyKES r1 + sk_1 <- genKeyKESWith allocator r1 vk_1 <- deriveVerKeyKES sk_1 forgetSignKeyKES sk_1 mlockedSeedFinalize r0 @@ -324,15 +317,13 @@ instance ( OptimizedKESAlgorithm d -- -- forgetting -- - forgetSignKeyKES (SignKeyCompactSumKES sk_0 r1 _ _) = do - forgetSignKeyKES sk_0 + forgetSignKeyKESWith allocator (SignKeyCompactSumKES sk_0 r1 _ _) = do + forgetSignKeyKESWith allocator sk_0 mlockedSeedFinalize r1 -instance ( KESSignAlgorithm m (CompactSumKES h d) - , UnsoundKESSignAlgorithm m d - , NaCl.MonadSodium m - , MonadST m - ) => UnsoundKESSignAlgorithm m (CompactSumKES h d) where +instance ( KESSignAlgorithm (CompactSumKES h d) + , UnsoundKESSignAlgorithm d + ) => UnsoundKESSignAlgorithm (CompactSumKES h d) where -- -- Raw serialise/deserialise - dangerous, do not use in production code. -- @@ -340,7 +331,7 @@ instance ( KESSignAlgorithm m (CompactSumKES h d) {-# NOINLINE rawSerialiseSignKeyKES #-} rawSerialiseSignKeyKES (SignKeyCompactSumKES sk r_1 vk_0 vk_1) = do ssk <- rawSerialiseSignKeyKES sk - sr1 <- NaCl.mlsbToByteString . mlockedSeedMLSB $ r_1 + sr1 <- mlsbToByteString . mlockedSeedMLSB $ r_1 return $ mconcat [ ssk , sr1 @@ -348,11 +339,11 @@ instance ( KESSignAlgorithm m (CompactSumKES h d) , rawSerialiseVerKeyKES vk_1 ] - {-# NOINLINE rawDeserialiseSignKeyKES #-} - rawDeserialiseSignKeyKES b = runMaybeT $ do + {-# NOINLINE rawDeserialiseSignKeyKESWith #-} + rawDeserialiseSignKeyKESWith allocator b = runMaybeT $ do guard (BS.length b == fromIntegral size_total) - sk <- MaybeT $ rawDeserialiseSignKeyKES b_sk - r <- MaybeT $ NaCl.mlsbFromByteStringCheck b_r + sk <- MaybeT $ rawDeserialiseSignKeyKESWith allocator b_sk + r <- MaybeT $ mlsbFromByteStringCheckWith allocator b_r vk_0 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk0 vk_1 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk1 return (SignKeyCompactSumKES sk (MLockedSeed r) vk_0 vk_1) @@ -406,7 +397,7 @@ deriving via OnlyCheckWhnfNamed "SignKeyKES (CompactSumKES h d)" (SignKeyKES (Co instance (KESAlgorithm d) => NoThunks (VerKeyKES (CompactSumKES h d)) instance ( OptimizedKESAlgorithm d - , NaCl.SodiumHashAlgorithm h + , SodiumHashAlgorithm h , SizeHash h ~ SeedSizeKES d , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) @@ -418,7 +409,7 @@ instance ( OptimizedKESAlgorithm d encodedSizeExpr _size = encodedVerKeyKESSizeExpr instance ( OptimizedKESAlgorithm d - , NaCl.SodiumHashAlgorithm h + , SodiumHashAlgorithm h , SizeHash h ~ SeedSizeKES d , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) @@ -458,7 +449,7 @@ deriving instance KESAlgorithm d => Eq (SigKES (CompactSumKES h d)) instance KESAlgorithm d => NoThunks (SigKES (CompactSumKES h d)) instance ( OptimizedKESAlgorithm d - , NaCl.SodiumHashAlgorithm h + , SodiumHashAlgorithm h , SizeHash h ~ SeedSizeKES d , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) @@ -470,7 +461,7 @@ instance ( OptimizedKESAlgorithm d encodedSizeExpr _size = encodedSigKESSizeExpr instance ( OptimizedKESAlgorithm d - , NaCl.SodiumHashAlgorithm h + , SodiumHashAlgorithm h , SizeHash h ~ SeedSizeKES d , NoThunks (VerKeyKES (CompactSumKES h d)) , KnownNat (SizeVerKeyKES (CompactSumKES h d)) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs index 17c8e9cac..ed0d7480d 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Mock.hs @@ -33,8 +33,10 @@ import Cardano.Crypto.Hash import Cardano.Crypto.Seed import Cardano.Crypto.KES.Class import Cardano.Crypto.Util -import Cardano.Crypto.MLockedSeed -import Cardano.Crypto.MonadSodium (mlsbAsByteString) +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium + ( mlsbAsByteString + ) data MockKES (t :: Nat) @@ -122,7 +124,7 @@ instance KnownNat t => KESAlgorithm (MockKES t) where | otherwise = Nothing -instance (Monad m, KnownNat t) => KESSignAlgorithm m (MockKES t) where +instance KnownNat t => KESSignAlgorithm (MockKES t) where data SignKeyKES (MockKES t) = SignKeyMockKES !(VerKeyKES (MockKES t)) !Period deriving stock (Show, Eq, Generic) @@ -130,11 +132,11 @@ instance (Monad m, KnownNat t) => KESSignAlgorithm m (MockKES t) where deriveVerKeyKES (SignKeyMockKES vk _) = return $! vk - updateKES () (SignKeyMockKES vk t') t = + updateKESWith _allocator () (SignKeyMockKES vk t') t = assert (t == t') $! if t+1 < totalPeriodsKES (Proxy @(MockKES t)) then return $! Just $! SignKeyMockKES vk (t+1) - else return $! Nothing + else return Nothing -- | Produce valid signature only with correct key, i.e., same iteration and -- allowed KES period. @@ -148,17 +150,17 @@ instance (Monad m, KnownNat t) => KESSignAlgorithm m (MockKES t) where -- Key generation -- - genKeyKES seed = do + genKeyKESWith _allocator seed = do let vk = VerKeyMockKES (runMonadRandomWithSeed (mkSeedFromBytes . mlsbAsByteString . mlockedSeedMLSB $ seed) getRandomWord64) return $! SignKeyMockKES vk 0 - forgetSignKeyKES = const $ return () + forgetSignKeyKESWith _ = const $ return () -instance (Monad m, KnownNat t) => UnsoundKESSignAlgorithm m (MockKES t) where +instance KnownNat t => UnsoundKESSignAlgorithm (MockKES t) where rawSerialiseSignKeyKES sk = return $ rawSerialiseSignKeyMockKES sk - rawDeserialiseSignKeyKES bs = + rawDeserialiseSignKeyKESWith _alloc bs = return $ rawDeserialiseSignKeyMockKES bs rawDeserialiseSignKeyMockKES :: KnownNat t diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs index 40f4bc74f..8aaf910ab 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/NeverUsed.hs @@ -49,20 +49,20 @@ instance KESAlgorithm NeverKES where rawDeserialiseVerKeyKES _ = Just NeverUsedVerKeyKES rawDeserialiseSigKES _ = Just NeverUsedSigKES -instance Monad m => KESSignAlgorithm m NeverKES where +instance KESSignAlgorithm NeverKES where data SignKeyKES NeverKES = NeverUsedSignKeyKES deriving (Show, Eq, Generic, NoThunks) - deriveVerKeyKES _ = return $! NeverUsedVerKeyKES + deriveVerKeyKES _ = return NeverUsedVerKeyKES signKES = error "KES not available" - updateKES = error "KES not available" + updateKESWith _ = error "KES not available" - genKeyKES _ = return $! NeverUsedSignKeyKES + genKeyKESWith _ _ = return NeverUsedSignKeyKES - forgetSignKeyKES = const $ return () + forgetSignKeyKESWith _ = const $ return () -instance Monad m => UnsoundKESSignAlgorithm m NeverKES where +instance UnsoundKESSignAlgorithm NeverKES where rawSerialiseSignKeyKES _ = return mempty - rawDeserialiseSignKeyKES _ = return $ Just NeverUsedSignKeyKES + rawDeserialiseSignKeyKESWith _ _ = return $ Just NeverUsedSignKeyKES diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index 1ace81793..c40bcffc1 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -33,19 +33,17 @@ import GHC.Generics (Generic) import GHC.TypeNats (Nat, KnownNat, natVal, type (*)) import NoThunks.Class (NoThunks) import Control.Monad.Trans.Maybe -import Control.Monad.Class.MonadThrow (MonadEvaluate) -import Control.Monad.Class.MonadST (MonadST) import Control.Monad ( (<$!>) ) import Cardano.Binary (FromCBOR (..), ToCBOR (..)) import Cardano.Crypto.DSIGN import Cardano.Crypto.KES.Class -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium.MLockedBytes import Cardano.Crypto.Util import Data.Unit.Strict (forceElemsToWHNF) -import Cardano.Crypto.MonadSodium (MonadSodium (..), MEq (..)) +import Cardano.Crypto.EqST (EqST (..)) data SimpleKES d (t :: Nat) @@ -144,14 +142,11 @@ instance ( DSIGNMAlgorithmBase d instance ( KESAlgorithm (SimpleKES d t) - , DSIGNMAlgorithm m d + , DSIGNMAlgorithm d , KnownNat t , KnownNat (SeedSizeDSIGNM d * t) - , MonadEvaluate m - , MonadSodium m - , MonadST m ) => - KESSignAlgorithm m (SimpleKES d t) where + KESSignAlgorithm (SimpleKES d t) where newtype SignKeyKES (SimpleKES d t) = ThunkySignKeySimpleKES (Vector (SignKeyDSIGNM d)) deriving Generic @@ -166,9 +161,9 @@ instance ( KESAlgorithm (SimpleKES d t) Nothing -> error ("SimpleKES.signKES: period out of range " ++ show j) Just sk -> SigSimpleKES <$!> (signDSIGNM ctxt a $! sk) - updateKES _ (ThunkySignKeySimpleKES sk) t + updateKESWith allocator _ (ThunkySignKeySimpleKES sk) t | t+1 < fromIntegral (natVal (Proxy @t)) = do - sk' <- Vec.mapM cloneKeyDSIGNM sk + sk' <- Vec.mapM (cloneKeyDSIGNMWith allocator) sk return $! Just $! SignKeySimpleKES sk' | otherwise = return Nothing @@ -177,24 +172,25 @@ instance ( KESAlgorithm (SimpleKES d t) -- Key generation -- - genKeyKES (MLockedSeed mlsb) = do + genKeyKESWith allocator (MLockedSeed mlsb) = do let seedSize = seedSizeDSIGNM (Proxy :: Proxy d) duration = fromIntegral (natVal (Proxy @t)) sks <- Vec.generateM duration $ \t -> do withMLSBChunk mlsb (fromIntegral t * fromIntegral seedSize) $ \mlsb' -> do - genKeyDSIGNM (MLockedSeed mlsb') + genKeyDSIGNMWith allocator (MLockedSeed mlsb') return $! SignKeySimpleKES sks -- -- Forgetting -- - forgetSignKeyKES (SignKeySimpleKES sks) = Vec.mapM_ forgetSignKeyDSIGNM sks + forgetSignKeyKESWith allocator (SignKeySimpleKES sks) = + Vec.mapM_ (forgetSignKeyDSIGNMWith allocator) sks -instance ( UnsoundDSIGNMAlgorithm m d, KnownNat t, KESSignAlgorithm m (SimpleKES d t)) - => UnsoundKESSignAlgorithm m (SimpleKES d t) where +instance ( UnsoundDSIGNMAlgorithm d, KnownNat t, KESSignAlgorithm (SimpleKES d t)) + => UnsoundKESSignAlgorithm (SimpleKES d t) where -- -- raw serialise/deserialise -- @@ -203,13 +199,13 @@ instance ( UnsoundDSIGNMAlgorithm m d, KnownNat t, KESSignAlgorithm m (SimpleKES BS.concat <$!> mapM rawSerialiseSignKeyDSIGNM (Vec.toList sks) - rawDeserialiseSignKeyKES bs + rawDeserialiseSignKeyKESWith allocator bs | let duration = fromIntegral (natVal (Proxy :: Proxy t)) sizeKey = fromIntegral (sizeSignKeyDSIGNM (Proxy :: Proxy d)) , skbs <- splitsAt (replicate duration sizeKey) bs , length skbs == duration = runMaybeT $ do - sks <- mapM (MaybeT . rawDeserialiseSignKeyDSIGNM) skbs + sks <- mapM (MaybeT . rawDeserialiseSignKeyDSIGNMWith allocator) skbs return $! SignKeySimpleKES (Vec.fromList sks) | otherwise @@ -222,7 +218,7 @@ deriving instance DSIGNMAlgorithmBase d => Show (SigKES (SimpleKES d t)) deriving instance DSIGNMAlgorithmBase d => Eq (VerKeyKES (SimpleKES d t)) deriving instance DSIGNMAlgorithmBase d => Eq (SigKES (SimpleKES d t)) -instance (Monad m, MEq m (SignKeyDSIGNM d)) => MEq m (SignKeyKES (SimpleKES d t)) where +instance EqST (SignKeyDSIGNM d) => EqST (SignKeyKES (SimpleKES d t)) where equalsM (ThunkySignKeySimpleKES a) (ThunkySignKeySimpleKES b) = -- No need to check that lengths agree, the types already guarantee this. Vec.and <$> Vec.zipWithM equalsM a b diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs index 85045adce..c2c8aced2 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Single.hs @@ -111,8 +111,7 @@ instance (DSIGNMAlgorithmBase d) => KESAlgorithm (SingleKES d) where {-# INLINE rawDeserialiseSigKES #-} -instance ( DSIGNMAlgorithm m d -- needed for secure forgetting - ) => KESSignAlgorithm m (SingleKES d) where +instance DSIGNMAlgorithm d => KESSignAlgorithm (SingleKES d) where newtype SignKeyKES (SingleKES d) = SignKeySingleKES (SignKeyDSIGNM d) deriveVerKeyKES (SignKeySingleKES v) = @@ -126,27 +125,28 @@ instance ( DSIGNMAlgorithm m d -- needed for secure forgetting assert (t == 0) $! SigSingleKES <$!> signDSIGNM ctxt a sk - updateKES _ctx (SignKeySingleKES _sk) _to = return Nothing + updateKESWith _allocator _ctx (SignKeySingleKES _sk) _to = return Nothing -- -- Key generation -- - genKeyKES seed = SignKeySingleKES <$!> genKeyDSIGNM seed + genKeyKESWith allocator seed = + SignKeySingleKES <$!> genKeyDSIGNMWith allocator seed -- -- forgetting -- - forgetSignKeyKES (SignKeySingleKES v) = - forgetSignKeyDSIGNM v + forgetSignKeyKESWith allocator (SignKeySingleKES v) = + forgetSignKeyDSIGNMWith allocator v -instance (KESSignAlgorithm m (SingleKES d), UnsoundDSIGNMAlgorithm m d) - => UnsoundKESSignAlgorithm m (SingleKES d) where +instance (KESSignAlgorithm (SingleKES d), UnsoundDSIGNMAlgorithm d) + => UnsoundKESSignAlgorithm (SingleKES d) where rawSerialiseSignKeyKES (SignKeySingleKES sk) = rawSerialiseSignKeyDSIGNM sk - rawDeserialiseSignKeyKES bs = - fmap SignKeySingleKES <$> rawDeserialiseSignKeyDSIGNM bs + rawDeserialiseSignKeyKESWith allocator bs = + fmap SignKeySingleKES <$> rawDeserialiseSignKeyDSIGNMWith allocator bs -- -- VerKey instances diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs index ff0381c1b..23de3f54b 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Sum.hs @@ -63,14 +63,13 @@ import Cardano.Crypto.Hash.Class import Cardano.Crypto.KES.Class import Cardano.Crypto.KES.Single (SingleKES) import Cardano.Crypto.Util -import Cardano.Crypto.MLockedSeed -import qualified Cardano.Crypto.MonadSodium as NaCl -import Control.Monad.Class.MonadST (MonadST) -import Control.Monad.Class.MonadThrow (MonadThrow) +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.Libsodium import Control.Monad.Trans.Maybe (MaybeT (..), runMaybeT) import Control.DeepSeq (NFData (..)) import GHC.TypeLits (KnownNat, type (+), type (*)) + -- | A 2^0 period KES type Sum0KES d = SingleKES d @@ -115,7 +114,7 @@ instance (NFData (SignKeyKES d), NFData (VerKeyKES d)) => rnf (sk, r, vk1, vk2) instance ( KESAlgorithm d - , NaCl.SodiumHashAlgorithm h -- needed for secure forgetting + , SodiumHashAlgorithm h -- needed for secure forgetting , SizeHash h ~ SeedSizeKES d -- can be relaxed , KnownNat ((SizeSignKeyKES d + SeedSizeKES d) + (2 * SizeVerKeyKES d)) , KnownNat (SizeSigKES d + (SizeVerKeyKES d * 2)) @@ -221,16 +220,13 @@ instance ( KESAlgorithm d off_vk1 = off_vk0 + size_vk {-# INLINEABLE rawDeserialiseSigKES #-} -instance ( KESSignAlgorithm m d - , NaCl.SodiumHashAlgorithm h -- needed for secure forgetting +instance ( KESSignAlgorithm d + , SodiumHashAlgorithm h -- needed for secure forgetting , SizeHash h ~ SeedSizeKES d -- can be relaxed - , NaCl.MonadSodium m - , MonadST m -- only needed for unsafe raw ser/deser - , MonadThrow m , KnownNat ((SizeSignKeyKES d + SeedSizeKES d) + (2 * SizeVerKeyKES d)) , KnownNat (SizeSigKES d + (SizeVerKeyKES d * 2)) ) - => KESSignAlgorithm m (SumKES h d) where + => KESSignAlgorithm (SumKES h d) where -- | From Figure 3: @(sk_0, r_1, vk_0, vk_1)@ -- data SignKeyKES (SumKES h d) = @@ -253,21 +249,21 @@ instance ( KESSignAlgorithm m d _T = totalPeriodsKES (Proxy :: Proxy d) - {-# NOINLINE updateKES #-} - updateKES ctx (SignKeySumKES sk r_1 vk_0 vk_1) t + {-# NOINLINE updateKESWith #-} + updateKESWith allocator ctx (SignKeySumKES sk r_1 vk_0 vk_1) t | t+1 < _T = runMaybeT $! do - sk' <- MaybeT $! updateKES ctx sk t + sk' <- MaybeT $! updateKESWith allocator ctx sk t r_1' <- MaybeT $! Just <$!> mlockedSeedCopy r_1 return $! SignKeySumKES sk' r_1' vk_0 vk_1 | t+1 == _T = do - sk' <- genKeyKES r_1 - r_1' <- mlockedSeedNewZero + sk' <- genKeyKESWith allocator r_1 + r_1' <- mlockedSeedNewZeroWith allocator return $! Just $! SignKeySumKES sk' r_1' vk_0 vk_1 | otherwise = runMaybeT $ do - sk' <- MaybeT $! updateKES ctx sk (t - _T) - r_1' <- MaybeT $! Just <$!> mlockedSeedCopy r_1 + sk' <- MaybeT $! updateKESWith allocator ctx sk (t - _T) + r_1' <- MaybeT $! Just <$!> mlockedSeedCopyWith allocator r_1 return $! SignKeySumKES sk' r_1' vk_0 vk_1 where _T = totalPeriodsKES (Proxy :: Proxy d) @@ -276,14 +272,14 @@ instance ( KESSignAlgorithm m d -- Key generation -- - {-# NOINLINE genKeyKES #-} - genKeyKES r = do - (r0raw, r1raw) <- NaCl.expandHash (Proxy :: Proxy h) (mlockedSeedMLSB r) + {-# NOINLINE genKeyKESWith #-} + genKeyKESWith allocator r = do + (r0raw, r1raw) <- expandHashWith allocator (Proxy :: Proxy h) (mlockedSeedMLSB r) let r0 = MLockedSeed r0raw r1 = MLockedSeed r1raw - sk_0 <- genKeyKES r0 + sk_0 <- genKeyKESWith allocator r0 vk_0 <- deriveVerKeyKES sk_0 - sk_1 <- genKeyKES r1 + sk_1 <- genKeyKESWith allocator r1 vk_1 <- deriveVerKeyKES sk_1 forgetSignKeyKES sk_1 mlockedSeedFinalize r0 @@ -292,15 +288,13 @@ instance ( KESSignAlgorithm m d -- -- forgetting -- - forgetSignKeyKES (SignKeySumKES sk_0 r1 _ _) = do - forgetSignKeyKES sk_0 + forgetSignKeyKESWith allocator (SignKeySumKES sk_0 r1 _ _) = do + forgetSignKeyKESWith allocator sk_0 mlockedSeedFinalize r1 -instance ( KESSignAlgorithm m (SumKES h d) - , UnsoundKESSignAlgorithm m d - , NaCl.MonadSodium m - , MonadST m - ) => UnsoundKESSignAlgorithm m (SumKES h d) where +instance ( KESSignAlgorithm (SumKES h d) + , UnsoundKESSignAlgorithm d + ) => UnsoundKESSignAlgorithm (SumKES h d) where -- -- Raw serialise/deserialise - dangerous, do not use in production code. -- @@ -308,7 +302,7 @@ instance ( KESSignAlgorithm m (SumKES h d) {-# NOINLINE rawSerialiseSignKeyKES #-} rawSerialiseSignKeyKES (SignKeySumKES sk r_1 vk_0 vk_1) = do ssk <- rawSerialiseSignKeyKES sk - sr1 <- NaCl.mlsbToByteString . mlockedSeedMLSB $ r_1 + sr1 <- mlsbToByteString . mlockedSeedMLSB $ r_1 return $ mconcat [ ssk , sr1 @@ -316,11 +310,11 @@ instance ( KESSignAlgorithm m (SumKES h d) , rawSerialiseVerKeyKES vk_1 ] - {-# NOINLINE rawDeserialiseSignKeyKES #-} - rawDeserialiseSignKeyKES b = runMaybeT $ do + {-# NOINLINE rawDeserialiseSignKeyKESWith #-} + rawDeserialiseSignKeyKESWith allocator b = runMaybeT $ do guard (BS.length b == fromIntegral size_total) - sk <- MaybeT $ rawDeserialiseSignKeyKES b_sk - r <- MaybeT $ NaCl.mlsbFromByteStringCheck b_r + sk <- MaybeT $ rawDeserialiseSignKeyKESWith allocator b_sk + r <- MaybeT $ mlsbFromByteStringCheckWith allocator b_r vk_0 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk0 vk_1 <- MaybeT . return $ rawDeserialiseVerKeyKES b_vk1 return (SignKeySumKES sk (MLockedSeed r) vk_0 vk_1) @@ -348,12 +342,12 @@ instance ( KESSignAlgorithm m (SumKES h d) deriving instance HashAlgorithm h => Show (VerKeyKES (SumKES h d)) deriving instance Eq (VerKeyKES (SumKES h d)) -instance (KESAlgorithm (SumKES h d), NaCl.SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) +instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => ToCBOR (VerKeyKES (SumKES h d)) where toCBOR = encodeVerKeyKES encodedSizeExpr _size = encodedVerKeyKESSizeExpr -instance (KESAlgorithm (SumKES h d), NaCl.SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) +instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => FromCBOR (VerKeyKES (SumKES h d)) where fromCBOR = decodeVerKeyKES {-# INLINE fromCBOR #-} @@ -388,12 +382,12 @@ deriving instance (KESAlgorithm d, KESAlgorithm (SumKES h d)) => Eq (SigKES (Sum instance KESAlgorithm d => NoThunks (SigKES (SumKES h d)) -instance (KESAlgorithm (SumKES h d), NaCl.SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) +instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => ToCBOR (SigKES (SumKES h d)) where toCBOR = encodeSigKES encodedSizeExpr _size = encodedSigKESSizeExpr -instance (KESAlgorithm (SumKES h d), NaCl.SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) +instance (KESAlgorithm (SumKES h d), SodiumHashAlgorithm h, SizeHash h ~ SeedSizeKES d) => FromCBOR (SigKES (SumKES h d)) where fromCBOR = decodeSigKES {-# INLINE fromCBOR #-} diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium.hs index 0201f7fc7..eaad15d05 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium.hs @@ -1,26 +1,48 @@ module Cardano.Crypto.Libsodium ( -- * Initialization sodiumInit, + -- * MLocked memory management MLockedForeignPtr, - withMLockedForeignPtr, - mlockedAllocForeignPtr, + MLockedAllocator, + finalizeMLockedForeignPtr, + mlockedAllocForeignPtr, + mlockedMalloc, traceMLockedForeignPtr, - -- * MLocked bytes + withMLockedForeignPtr, + + -- * MLocked bytes ('MLockedSizedBytes') MLockedSizedBytes, + + mlsbAsByteString, + mlsbCompare, + mlsbCopy, + mlsbCopyWith, + mlsbEq, + mlsbFinalize, mlsbFromByteString, mlsbFromByteStringCheck, - mlsbAsByteString, + mlsbFromByteStringCheckWith, + mlsbFromByteStringWith, + mlsbNew, + mlsbNewWith, + mlsbNewZero, + mlsbNewZeroWith, mlsbToByteString, - mlsbFinalize, - mlsbCopy, + mlsbUseAsCPtr, + mlsbUseAsSizedPtr, + mlsbZero, + -- * Hashing - SodiumHashAlgorithm (..), - digestMLockedStorable, digestMLockedBS, + digestMLockedStorable, expandHash, + expandHashWith, + SodiumHashAlgorithm (..), ) where import Cardano.Crypto.Libsodium.Init -import Cardano.Crypto.MonadSodium +import Cardano.Crypto.Libsodium.Memory +import Cardano.Crypto.Libsodium.Hash +import Cardano.Crypto.Libsodium.MLockedBytes diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Hash.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Hash.hs index fc9bd8073..87d89cf50 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Hash.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Hash.hs @@ -9,6 +9,7 @@ module Cardano.Crypto.Libsodium.Hash ( digestMLockedStorable, digestMLockedBS, expandHash, + expandHashWith, ) where import Data.Proxy (Proxy (..)) @@ -19,10 +20,9 @@ import Data.Word (Word8) import GHC.TypeLits import Cardano.Crypto.Hash (HashAlgorithm(SizeHash)) +import Cardano.Crypto.Libsodium.Memory import Cardano.Crypto.Libsodium.Hash.Class import Cardano.Crypto.Libsodium.MLockedBytes.Internal -import Cardano.Crypto.MonadSodium.Class -import Cardano.Crypto.MonadSodium.Alloc import Control.Monad.Class.MonadST (MonadST (..)) import Control.Monad.Class.MonadThrow (MonadThrow) import Control.Monad.ST.Unsafe (unsafeIOToST) @@ -33,19 +33,28 @@ import Control.Monad.ST.Unsafe (unsafeIOToST) expandHash :: forall h m proxy. - (SodiumHashAlgorithm h, MonadSodium m, MonadST m, MonadThrow m) + (SodiumHashAlgorithm h, MonadST m, MonadThrow m) => proxy h -> MLockedSizedBytes (SizeHash h) -> m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h)) -expandHash h (MLSB sfptr) = do +expandHash = expandHashWith mlockedMalloc + +expandHashWith + :: forall h m proxy. + (SodiumHashAlgorithm h, MonadST m, MonadThrow m) + => MLockedAllocator m + -> proxy h + -> MLockedSizedBytes (SizeHash h) + -> m (MLockedSizedBytes (SizeHash h), MLockedSizedBytes (SizeHash h)) +expandHashWith allocator h (MLSB sfptr) = do withMLockedForeignPtr sfptr $ \ptr -> do - l <- mlockedAlloca size1 $ \ptr' -> do + l <- mlockedAllocaWith allocator size1 $ \ptr' -> do withLiftST $ \liftST -> liftST . unsafeIOToST $ do poke ptr' (1 :: Word8) copyMem (castPtr (plusPtr ptr' 1)) ptr size naclDigestPtr h ptr' (fromIntegral size1) - r <- mlockedAlloca size1 $ \ptr' -> do + r <- mlockedAllocaWith allocator size1 $ \ptr' -> do withLiftST $ \liftST -> liftST . unsafeIOToST $ do poke ptr' (2 :: Word8) copyMem (castPtr (plusPtr ptr' 1)) ptr size diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes.hs index 48ac366d4..a08391552 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes.hs @@ -17,6 +17,12 @@ module Cardano.Crypto.Libsodium.MLockedBytes ( traceMLSB, mlsbCompare, mlsbEq, + + mlsbNewWith, + mlsbNewZeroWith, + mlsbCopyWith, + mlsbFromByteStringWith, + mlsbFromByteStringCheckWith, ) where import Cardano.Crypto.Libsodium.MLockedBytes.Internal diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs index 69c03f2c6..1d50ac25f 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs @@ -1,12 +1,13 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingVia #-} -{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} + module Cardano.Crypto.Libsodium.MLockedBytes.Internal ( -- * The MLockesSizedBytes type MLockedSizedBytes (..), @@ -25,12 +26,18 @@ module Cardano.Crypto.Libsodium.MLockedBytes.Internal ( withMLSB, withMLSBChunk, + mlsbNewWith, + mlsbNewZeroWith, + mlsbCopyWith, + -- * Dangerous Functions traceMLSB, mlsbFromByteString, mlsbFromByteStringCheck, mlsbAsByteString, mlsbToByteString, + mlsbFromByteStringWith, + mlsbFromByteStringCheckWith, ) where import Control.DeepSeq (NFData (..)) @@ -45,11 +52,10 @@ import GHC.TypeLits (KnownNat, Nat, natVal) import NoThunks.Class (NoThunks) import Cardano.Foreign -import Cardano.Crypto.MonadSodium.Class -import Cardano.Crypto.MonadSodium.Alloc +import Cardano.Crypto.Libsodium.Memory import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..)) import Cardano.Crypto.Libsodium.C -import Cardano.Crypto.MEqOrd +import Cardano.Crypto.EqST import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as BSI @@ -78,7 +84,7 @@ instance KnownNat n => Show (MLockedSizedBytes n) where -- hexstr = concatMap (printf "%02x") bytes -- in "MLSB " ++ hexstr -instance (MonadSodium m, MonadST m, KnownNat n) => MEq m (MLockedSizedBytes n) where +instance KnownNat n => EqST (MLockedSizedBytes n) where equalsM = mlsbEq nextPowerOf2 :: forall n. (Num n, Ord n, Bits n) => n -> n @@ -94,11 +100,11 @@ traceMLSB :: KnownNat n => MLockedSizedBytes n -> IO () traceMLSB = print {-# DEPRECATED traceMLSB "Don't leave traceMLockedForeignPtr in production" #-} -withMLSB :: forall b n m. (MonadSodium m) => MLockedSizedBytes n -> (Ptr (SizedVoid n) -> m b) -> m b +withMLSB :: forall b n m. (MonadST m) => MLockedSizedBytes n -> (Ptr (SizedVoid n) -> m b) -> m b withMLSB (MLSB fptr) action = withMLockedForeignPtr fptr action withMLSBChunk :: forall b n n' m. - (MonadSodium m, MonadST m, KnownNat n, KnownNat n') + (MonadST m, KnownNat n, KnownNat n') => MLockedSizedBytes n -> Int -> (MLockedSizedBytes n' -> m b) @@ -123,9 +129,12 @@ mlsbSize mlsb = fromInteger (natVal mlsb) -- | Allocate a new 'MLockedSizedBytes'. The caller is responsible for -- deallocating it ('mlsbFinalize') when done with it. The contents of the -- memory block is undefined. -mlsbNew :: forall n m. (KnownNat n, MonadSodium m) => m (MLockedSizedBytes n) -mlsbNew = - MLSB <$> mlockedAllocForeignPtrBytes size align +mlsbNew :: forall n m. (KnownNat n, MonadST m) => m (MLockedSizedBytes n) +mlsbNew = mlsbNewWith mlockedMalloc + +mlsbNewWith :: forall n m. MLockedAllocator m -> (KnownNat n, MonadST m) => m (MLockedSizedBytes n) +mlsbNewWith allocator = + MLSB <$> mlockedAllocForeignPtrBytesWith allocator size align where size = fromInteger (natVal (Proxy @n)) align = nextPowerOf2 size @@ -133,21 +142,33 @@ mlsbNew = -- | Allocate a new 'MLockedSizedBytes', and pre-fill it with zeroes. -- The caller is responsible for deallocating it ('mlsbFinalize') when done -- with it. (See also 'mlsbNew'). -mlsbNewZero :: forall n m. (KnownNat n, MonadSodium m) => m (MLockedSizedBytes n) -mlsbNewZero = do - mlsb <- mlsbNew +mlsbNewZero :: forall n m. (KnownNat n, MonadST m) => m (MLockedSizedBytes n) +mlsbNewZero = mlsbNewZeroWith mlockedMalloc + +mlsbNewZeroWith :: forall n m. (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSizedBytes n) +mlsbNewZeroWith allocator = do + mlsb <- mlsbNewWith allocator mlsbZero mlsb return mlsb -- | Overwrite an existing 'MLockedSizedBytes' with zeroes. -mlsbZero :: forall n m. (KnownNat n, MonadSodium m) => MLockedSizedBytes n -> m () +mlsbZero :: forall n m. (KnownNat n, MonadST m) => MLockedSizedBytes n -> m () mlsbZero mlsb = do withMLSB mlsb $ \ptr -> zeroMem ptr (mlsbSize mlsb) -- | Create a deep mlocked copy of an 'MLockedSizedBytes'. -mlsbCopy :: forall n m. (KnownNat n, MonadSodium m) => MLockedSizedBytes n -> m (MLockedSizedBytes n) -mlsbCopy src = mlsbUseAsCPtr src $ \ptrSrc -> do - dst <- mlsbNew +mlsbCopy :: forall n m. (KnownNat n, MonadST m) + => MLockedSizedBytes n + -> m (MLockedSizedBytes n) +mlsbCopy = mlsbCopyWith mlockedMalloc + +mlsbCopyWith :: + forall n m. (KnownNat n, MonadST m) + => MLockedAllocator m + -> MLockedSizedBytes n + -> m (MLockedSizedBytes n) +mlsbCopyWith allocator src = mlsbUseAsCPtr src $ \ptrSrc -> do + dst <- mlsbNewWith allocator withMLSB dst $ \ptrDst -> do copyMem (castPtr ptrDst) (castPtr ptrSrc) (mlsbSize src) return dst @@ -160,10 +181,14 @@ mlsbCopy src = mlsbUseAsCPtr src $ \ptrSrc -> do -- 'mlsbNew' or 'mlsbNewZero' to create 'MLockedSizedBytes' values, and -- manipulate them through 'withMLSB', 'mlsbUseAsCPtr', or 'mlsbUseAsSizedPtr'. -- (See also 'mlsbFromByteStringCheck') -mlsbFromByteString :: forall n m. (KnownNat n, MonadSodium m, MonadST m) +mlsbFromByteString :: forall n m. (KnownNat n, MonadST m) => BS.ByteString -> m (MLockedSizedBytes n) -mlsbFromByteString bs = do - dst <- mlsbNew +mlsbFromByteString = mlsbFromByteStringWith mlockedMalloc + +mlsbFromByteStringWith :: forall n m. (KnownNat n, MonadST m) + => MLockedAllocator m -> BS.ByteString -> m (MLockedSizedBytes n) +mlsbFromByteStringWith allocator bs = do + dst <- mlsbNewWith allocator withMLSB dst $ \ptr -> do withLiftST $ \liftST -> liftST . unsafeIOToST $ do BS.useAsCStringLen bs $ \(ptrBS, len) -> do @@ -178,10 +203,19 @@ mlsbFromByteString bs = do -- 'mlsbNew' or 'mlsbNewZero' to create 'MLockedSizedBytes' values, and -- manipulate them through 'withMLSB', 'mlsbUseAsCPtr', or 'mlsbUseAsSizedPtr'. -- (See also 'mlsbFromByteString') -mlsbFromByteStringCheck :: forall n m. (KnownNat n, MonadSodium m, MonadST m) => BS.ByteString -> m (Maybe (MLockedSizedBytes n)) -mlsbFromByteStringCheck bs +mlsbFromByteStringCheck :: forall n m. (KnownNat n, MonadST m) + => BS.ByteString + -> m (Maybe (MLockedSizedBytes n)) +mlsbFromByteStringCheck = mlsbFromByteStringCheckWith mlockedMalloc + +mlsbFromByteStringCheckWith :: + forall n m. (KnownNat n, MonadST m) + => MLockedAllocator m + -> BS.ByteString + -> m (Maybe (MLockedSizedBytes n)) +mlsbFromByteStringCheckWith allocator bs | BS.length bs /= size = return Nothing - | otherwise = Just <$> mlsbFromByteString bs + | otherwise = Just <$> mlsbFromByteStringWith allocator bs where size :: Int size = fromInteger (natVal (Proxy @n)) @@ -200,7 +234,7 @@ mlsbAsByteString mlsb@(MLSB (SFP fptr)) = BSI.PS (castForeignPtr fptr) 0 size -- | /Note:/ this function will leak mlocked memory to the Haskell heap -- and should not be used in production code. -mlsbToByteString :: forall n m. (KnownNat n, MonadSodium m, MonadST m) => MLockedSizedBytes n -> m BS.ByteString +mlsbToByteString :: forall n m. (KnownNat n, MonadST m) => MLockedSizedBytes n -> m BS.ByteString mlsbToByteString mlsb = withMLSB mlsb $ \ptr -> withLiftST $ \liftST -> liftST . unsafeIOToST $ BS.packCStringLen (castPtr ptr, size) @@ -212,7 +246,7 @@ mlsbToByteString mlsb = -- to never copy the contents of the 'MLockedSizedBytes' value into managed -- memory through the raw pointer, because that would violate the -- secure-forgetting property of mlocked memory. -mlsbUseAsCPtr :: MonadSodium m => MLockedSizedBytes n -> (Ptr Word8 -> m r) -> m r +mlsbUseAsCPtr :: MonadST m => MLockedSizedBytes n -> (Ptr Word8 -> m r) -> m r mlsbUseAsCPtr (MLSB x) k = withMLockedForeignPtr x (k . castPtr) @@ -220,18 +254,18 @@ mlsbUseAsCPtr (MLSB x) k = -- should be taken to never copy the contents of the 'MLockedSizedBytes' value -- into managed memory through the sized pointer, because that would violate -- the secure-forgetting property of mlocked memory. -mlsbUseAsSizedPtr :: forall n r m. (MonadSodium m) => MLockedSizedBytes n -> (SizedPtr n -> m r) -> m r +mlsbUseAsSizedPtr :: forall n r m. (MonadST m) => MLockedSizedBytes n -> (SizedPtr n -> m r) -> m r mlsbUseAsSizedPtr (MLSB x) k = withMLockedForeignPtr x (k . SizedPtr . castPtr) -- | Calls 'finalizeMLockedForeignPtr' on underlying pointer. -- This function invalidates argument. -- -mlsbFinalize :: MonadSodium m => MLockedSizedBytes n -> m () +mlsbFinalize :: MonadST m => MLockedSizedBytes n -> m () mlsbFinalize (MLSB ptr) = finalizeMLockedForeignPtr ptr -- | 'compareM' on 'MLockedSizedBytes' -mlsbCompare :: forall n m. (MonadSodium m, MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Ordering +mlsbCompare :: forall n m. (MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Ordering mlsbCompare (MLSB x) (MLSB y) = withMLockedForeignPtr x $ \x' -> withMLockedForeignPtr y $ \y' -> do @@ -241,5 +275,5 @@ mlsbCompare (MLSB x) (MLSB y) = size = fromInteger $ natVal (Proxy @n) -- | 'equalsM' on 'MLockedSizedBytes' -mlsbEq :: forall n m. (MonadSodium m, MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Bool +mlsbEq :: forall n m. (MonadST m, KnownNat n) => MLockedSizedBytes n -> MLockedSizedBytes n -> m Bool mlsbEq a b = (== EQ) <$> mlsbCompare a b diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs new file mode 100644 index 000000000..0677a9a13 --- /dev/null +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE StandaloneDeriving #-} + +module Cardano.Crypto.Libsodium.MLockedSeed +where + +import Cardano.Crypto.Libsodium.MLockedBytes ( + MLockedSizedBytes, + mlsbCopyWith, + mlsbFinalize, + mlsbNewWith, + mlsbNewZeroWith, + mlsbUseAsCPtr, + mlsbUseAsSizedPtr, + ) +import Cardano.Crypto.Libsodium.Memory ( + MLockedAllocator, + mlockedMalloc, + ) +import Cardano.Crypto.EqST ( + EqST (..), + ) +import Cardano.Foreign (SizedPtr) +import Control.DeepSeq (NFData) +import Control.Monad.Class.MonadST (MonadST) +import Data.Word (Word8) +import Foreign.Ptr (Ptr) +import GHC.TypeNats (KnownNat) +import NoThunks.Class (NoThunks) + +-- | A seed of size @n@, stored in mlocked memory. This is required to prevent +-- the seed from leaking to disk via swapping and reclaiming or scanning memory +-- after its content has been moved. +newtype MLockedSeed n = MLockedSeed {mlockedSeedMLSB :: MLockedSizedBytes n} + deriving (NFData, NoThunks) + +deriving via + MLockedSizedBytes n + instance + KnownNat n => EqST (MLockedSeed n) + +withMLockedSeedAsMLSB + :: Functor m + => (MLockedSizedBytes n -> m (MLockedSizedBytes n)) + -> MLockedSeed n + -> m (MLockedSeed n) +withMLockedSeedAsMLSB action = + fmap MLockedSeed . action . mlockedSeedMLSB + +mlockedSeedCopy :: (KnownNat n, MonadST m) => MLockedSeed n -> m (MLockedSeed n) +mlockedSeedCopy = mlockedSeedCopyWith mlockedMalloc + +mlockedSeedCopyWith + :: (KnownNat n, MonadST m) + => MLockedAllocator m + -> MLockedSeed n + -> m (MLockedSeed n) +mlockedSeedCopyWith allocator = withMLockedSeedAsMLSB (mlsbCopyWith allocator) + +mlockedSeedNew :: (KnownNat n, MonadST m) => m (MLockedSeed n) +mlockedSeedNew = mlockedSeedNewWith mlockedMalloc + +mlockedSeedNewWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSeed n) +mlockedSeedNewWith allocator = + MLockedSeed <$> mlsbNewWith allocator + +mlockedSeedNewZero :: (KnownNat n, MonadST m) => m (MLockedSeed n) +mlockedSeedNewZero = mlockedSeedNewZeroWith mlockedMalloc + +mlockedSeedNewZeroWith :: (KnownNat n, MonadST m) => MLockedAllocator m -> m (MLockedSeed n) +mlockedSeedNewZeroWith allocator = + MLockedSeed <$> mlsbNewZeroWith allocator + +mlockedSeedFinalize :: (MonadST m) => MLockedSeed n -> m () +mlockedSeedFinalize = mlsbFinalize . mlockedSeedMLSB + +mlockedSeedUseAsCPtr :: (MonadST m) => MLockedSeed n -> (Ptr Word8 -> m b) -> m b +mlockedSeedUseAsCPtr seed = mlsbUseAsCPtr (mlockedSeedMLSB seed) + +mlockedSeedUseAsSizedPtr :: (MonadST m) => MLockedSeed n -> (SizedPtr n -> m b) -> m b +mlockedSeedUseAsSizedPtr seed = mlsbUseAsSizedPtr (mlockedSeedMLSB seed) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 1e830c403..3c04e37e5 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -4,7 +4,32 @@ module Cardano.Crypto.Libsodium.Memory ( withMLockedForeignPtr, finalizeMLockedForeignPtr, traceMLockedForeignPtr, + + -- * MLocked allocations mlockedMalloc, + MLockedAllocator (..), + AllocatorEvent(..), + getAllocatorEvent, + + mlockedAlloca, + mlockedAllocaSized, + mlockedAllocForeignPtr, + mlockedAllocForeignPtrBytes, + + -- * Allocations using an explicit allocator + mlockedAllocaWith, + mlockedAllocaSizedWith, + mlockedAllocForeignPtrWith, + mlockedAllocForeignPtrBytesWith, + + -- * Unmanaged memory, generalized to 'MonadST' + zeroMem, + copyMem, + allocaBytes, + + -- * ByteString memory access, generalized to 'MonadST' + useByteStringAsCStringLen, + packByteStringCStringLen, ) where import Cardano.Crypto.Libsodium.Memory.Internal diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 1bc425672..79e4162dd 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -1,40 +1,78 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE CPP #-} -{-# OPTIONS_GHC -fprof-auto #-} module Cardano.Crypto.Libsodium.Memory.Internal ( -- * High-level memory management MLockedForeignPtr (..), withMLockedForeignPtr, finalizeMLockedForeignPtr, traceMLockedForeignPtr, + + -- * MLocked allocations mlockedMalloc, - -- * Low-level memory function - sodiumMalloc, - sodiumFree, + MLockedAllocator (..), + AllocatorEvent(..), + getAllocatorEvent, + + mlockedAlloca, + mlockedAllocaSized, + mlockedAllocForeignPtr, + mlockedAllocForeignPtrBytes, + + -- * Allocations using an explicit allocator + mlockedAllocaWith, + mlockedAllocaSizedWith, + mlockedAllocForeignPtrWith, + mlockedAllocForeignPtrBytesWith, + + -- * Unmanaged memory, generalized to 'MonadST' + zeroMem, + copyMem, + allocaBytes, + + -- * ByteString memory access, generalized to 'MonadST' + useByteStringAsCStringLen, + packByteStringCStringLen, + + -- * Helper + unsafeIOToMonadST ) where import Control.DeepSeq (NFData (..), rwhnf) -import Control.Exception (mask_) -import Control.Monad (when) +import Control.Exception (Exception, mask_) +import Control.Monad (when, void) +import Control.Monad.Class.MonadST +import Control.Monad.Class.MonadThrow (MonadThrow (bracket)) +import Control.Monad.ST +import Control.Monad.ST.Unsafe (unsafeIOToST, unsafeSTToIO) +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) -import Data.Proxy (Proxy (..)) +import Data.Typeable +import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) +import Foreign.C.String (CStringLen) import Foreign.C.Types (CSize (..)) -import Foreign.Ptr (Ptr, nullPtr) -import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, finalizeForeignPtr) import Foreign.Concurrent (newForeignPtr) -import Foreign.Storable (Storable (peek)) +import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, touchForeignPtr) +import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) +import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) -import GHC.TypeLits (KnownNat, natVal) +import Foreign.Ptr (Ptr, nullPtr, castPtr) +import Foreign.Storable (Storable (peek), sizeOf, alignment) import GHC.IO.Exception (ioException) +import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) import System.IO.Unsafe (unsafePerformIO) import Cardano.Crypto.Libsodium.C +import Cardano.Foreign (c_memset, c_memcpy, SizedPtr (..)) import Cardano.Memory.Pool (initPool, grabNextBlock, Pool) -- | Foreign pointer to securely allocated memory. @@ -44,24 +82,30 @@ newtype MLockedForeignPtr a = SFP { _unwrapMLockedForeignPtr :: ForeignPtr a } instance NFData (MLockedForeignPtr a) where rnf = rwhnf . _unwrapMLockedForeignPtr -withMLockedForeignPtr :: forall a b. MLockedForeignPtr a -> (Ptr a -> IO b) -> IO b -withMLockedForeignPtr = coerce (withForeignPtr @a @b) +withMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> (Ptr a -> m b) -> m b +withMLockedForeignPtr (SFP fptr) f = do + r <- f (unsafeForeignPtrToPtr fptr) + r <$ unsafeIOToMonadST (touchForeignPtr fptr) + +finalizeMLockedForeignPtr :: MonadST m => MLockedForeignPtr a -> m () +finalizeMLockedForeignPtr (SFP fptr) = withLiftST $ \lift -> + (lift . unsafeIOToST) (finalizeForeignPtr fptr) -finalizeMLockedForeignPtr :: forall a. MLockedForeignPtr a -> IO () -finalizeMLockedForeignPtr = coerce (finalizeForeignPtr @a) +{-# WARNING traceMLockedForeignPtr "Do not use traceMLockedForeignPtr in production" #-} -traceMLockedForeignPtr :: (Storable a, Show a) => MLockedForeignPtr a -> IO () +traceMLockedForeignPtr :: (Storable a, Show a, MonadST m) => MLockedForeignPtr a -> m () traceMLockedForeignPtr fptr = withMLockedForeignPtr fptr $ \ptr -> do - a <- peek ptr - print a + a <- unsafeIOToMonadST (peek ptr) + traceShowM a -{-# DEPRECATED traceMLockedForeignPtr "Don't leave traceMLockedForeignPtr in production" #-} +unsafeIOToMonadST :: MonadST m => IO a -> m a +unsafeIOToMonadST action = withLiftST ($ unsafeIOToST action) -makeMLockedPool :: forall n. KnownNat n => IO (Pool n) +makeMLockedPool :: forall n s. KnownNat n => ST s (Pool n s) makeMLockedPool = do initPool (max 1 . fromIntegral $ 4096 `div` natVal (Proxy @n) `div` 64) - (\size -> mask_ $ do + (\size -> unsafeIOToST $ mask_ $ do ptr <- sodiumMalloc (fromIntegral size) newForeignPtr ptr (sodiumFree ptr (fromIntegral size)) ) @@ -72,39 +116,50 @@ makeMLockedPool = do eraseMem :: forall n a. KnownNat n => Proxy n -> Ptr a -> IO () eraseMem proxy ptr = fillBytes ptr 0xff (fromIntegral $ natVal proxy) -mlockedPool32 :: Pool 32 -mlockedPool32 = unsafePerformIO makeMLockedPool +mlockedPool32 :: Pool 32 RealWorld +mlockedPool32 = unsafePerformIO $ stToIO makeMLockedPool {-# NOINLINE mlockedPool32 #-} -mlockedPool64 :: Pool 64 -mlockedPool64 = unsafePerformIO makeMLockedPool +mlockedPool64 :: Pool 64 RealWorld +mlockedPool64 = unsafePerformIO $ stToIO makeMLockedPool {-# NOINLINE mlockedPool64 #-} -mlockedPool128 :: Pool 128 -mlockedPool128 = unsafePerformIO makeMLockedPool +mlockedPool128 :: Pool 128 RealWorld +mlockedPool128 = unsafePerformIO $ stToIO makeMLockedPool {-# NOINLINE mlockedPool128 #-} -mlockedPool256 :: Pool 256 -mlockedPool256 = unsafePerformIO makeMLockedPool +mlockedPool256 :: Pool 256 RealWorld +mlockedPool256 = unsafePerformIO $ stToIO makeMLockedPool {-# NOINLINE mlockedPool256 #-} -mlockedPool512 :: Pool 512 -mlockedPool512 = unsafePerformIO makeMLockedPool +mlockedPool512 :: Pool 512 RealWorld +mlockedPool512 = unsafePerformIO $ stToIO makeMLockedPool {-# NOINLINE mlockedPool512 #-} -mlockedMalloc :: CSize -> IO (MLockedForeignPtr a) -mlockedMalloc size = SFP <$> do +data AllocatorException = + AllocatorNoTracer + | AllocatorNoGenerator + deriving Show + +instance Exception AllocatorException + +mlockedMalloc :: MonadST m => MLockedAllocator m +mlockedMalloc = + MLockedAllocator { mlAllocate = \ size -> withLiftST ($ unsafeIOToST (mlockedMallocIO size)) } + +mlockedMallocIO :: CSize -> IO (MLockedForeignPtr a) +mlockedMallocIO size = SFP <$> do if | size <= 32 -> do - coerce $ grabNextBlock mlockedPool32 + coerce $ stToIO $ grabNextBlock mlockedPool32 | size <= 64 -> do - coerce $ grabNextBlock mlockedPool64 + coerce $ stToIO $ grabNextBlock mlockedPool64 | size <= 128 -> do - coerce $ grabNextBlock mlockedPool128 + coerce $ stToIO $ grabNextBlock mlockedPool128 | size <= 256 -> do - coerce $ grabNextBlock mlockedPool256 + coerce $ stToIO $ grabNextBlock mlockedPool256 | size <= 512 -> do - coerce $ grabNextBlock mlockedPool512 + coerce $ stToIO $ grabNextBlock mlockedPool512 | otherwise -> do mask_ $ do ptr <- sodiumMalloc size @@ -130,3 +185,98 @@ sodiumFree ptr size = do errno <- getErrno ioException $ errnoToIOError "c_sodium_munlock" errno Nothing Nothing c_sodium_free ptr + +zeroMem :: MonadST m => Ptr a -> CSize -> m () +zeroMem ptr size = unsafeIOToMonadST . void $ c_memset (castPtr ptr) 0 size + +copyMem :: MonadST m => Ptr a -> Ptr a -> CSize -> m () +copyMem dst src size = unsafeIOToMonadST . void $ c_memcpy (castPtr dst) (castPtr src) size + +allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b +allocaBytes size f = + unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) + +useByteStringAsCStringLen :: ByteString -> (CStringLen -> ST s a) -> ST s a +useByteStringAsCStringLen bs f = + allocaBytes (BS.length bs + 1) $ \buf -> do + len <- unsafeIOToST $ BS.unsafeUseAsCStringLen bs $ \(ptr, len) -> + len <$ copyMem buf ptr (fromIntegral len) + f (buf, len) + +packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString +packByteStringCStringLen (ptr, len) = + withLiftST $ \lift -> lift . unsafeIOToST $ BS.packCStringLen (ptr, len) + +data AllocatorEvent where + AllocatorEvent :: (Show e, Typeable e) => e -> AllocatorEvent + +instance Show AllocatorEvent where + show (AllocatorEvent e) = "(AllocatorEvent " ++ show e ++ ")" + +getAllocatorEvent :: forall e. Typeable e => AllocatorEvent -> Maybe e +getAllocatorEvent (AllocatorEvent e) = cast e + +newtype MLockedAllocator m = + MLockedAllocator + { mlAllocate :: forall a. CSize -> m (MLockedForeignPtr a) + } + +mlockedAllocaSized :: forall m n b. (MonadST m, MonadThrow m, KnownNat n) => (SizedPtr n -> m b) -> m b +mlockedAllocaSized = mlockedAllocaSizedWith mlockedMalloc + +mlockedAllocaSizedWith :: + forall m n b. (MonadST m, MonadThrow m, KnownNat n) + => MLockedAllocator m + -> (SizedPtr n -> m b) + -> m b +mlockedAllocaSizedWith allocator k = mlockedAllocaWith allocator size (k . SizedPtr) where + size :: CSize + size = fromInteger (natVal (Proxy @n)) + +mlockedAllocForeignPtrBytes :: MonadST m => CSize -> CSize -> m (MLockedForeignPtr a) +mlockedAllocForeignPtrBytes = mlockedAllocForeignPtrBytesWith mlockedMalloc + +mlockedAllocForeignPtrBytesWith :: MLockedAllocator m -> CSize -> CSize -> m (MLockedForeignPtr a) +mlockedAllocForeignPtrBytesWith allocator size align = do + mlAllocate allocator size' + where + size' :: CSize + size' + | m == 0 = size + | otherwise = (q + 1) * align + where + (q,m) = size `quotRem` align + +mlockedAllocForeignPtr :: forall a m . (MonadST m, Storable a) => m (MLockedForeignPtr a) +mlockedAllocForeignPtr = mlockedAllocForeignPtrWith mlockedMalloc + +mlockedAllocForeignPtrWith :: + forall a m. Storable a + => MLockedAllocator m + -> m (MLockedForeignPtr a) +mlockedAllocForeignPtrWith allocator = + mlockedAllocForeignPtrBytesWith allocator size align + where + dummy :: a + dummy = undefined + + size :: CSize + size = fromIntegral $ sizeOf dummy + + align :: CSize + align = fromIntegral $ alignment dummy + +mlockedAlloca :: forall a b m. (MonadST m, MonadThrow m) => CSize -> (Ptr a -> m b) -> m b +mlockedAlloca = mlockedAllocaWith mlockedMalloc + +mlockedAllocaWith :: + forall a b m. (MonadThrow m, MonadST m) + => MLockedAllocator m + -> CSize + -> (Ptr a -> m b) + -> m b +mlockedAllocaWith allocator size = + bracket alloc free . flip withMLockedForeignPtr + where + alloc = mlAllocate allocator size + free = finalizeMLockedForeignPtr diff --git a/cardano-crypto-class/src/Cardano/Crypto/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/MLockedSeed.hs deleted file mode 100644 index e5f9fc9ce..000000000 --- a/cardano-crypto-class/src/Cardano/Crypto/MLockedSeed.hs +++ /dev/null @@ -1,64 +0,0 @@ -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE MultiParamTypeClasses #-} -module Cardano.Crypto.MLockedSeed -where - -import Cardano.Crypto.MonadSodium - ( MLockedSizedBytes - , MonadSodium (..) - , mlsbCopy - , mlsbNew - , mlsbNewZero - , mlsbFinalize - , mlsbUseAsCPtr - , mlsbUseAsSizedPtr - , MEq (..) - ) -import Cardano.Foreign (SizedPtr) -import GHC.TypeNats (KnownNat) -import Control.DeepSeq (NFData) -import NoThunks.Class (NoThunks) -import Foreign.Ptr (Ptr) -import Data.Word (Word8) -import Control.Monad.Class.MonadST (MonadST) - --- | A seed of size @n@, stored in mlocked memory. This is required to prevent --- the seed from leaking to disk via swapping and reclaiming or scanning memory --- after its content has been moved. -newtype MLockedSeed n = - MLockedSeed { mlockedSeedMLSB :: MLockedSizedBytes n } - deriving (NFData, NoThunks) - -deriving via (MLockedSizedBytes n) - instance (MonadSodium m, MonadST m, KnownNat n) => MEq m (MLockedSeed n) - -withMLockedSeedAsMLSB :: Functor m - => (MLockedSizedBytes n -> m (MLockedSizedBytes n)) - -> MLockedSeed n - -> m (MLockedSeed n) -withMLockedSeedAsMLSB action = - fmap MLockedSeed . action . mlockedSeedMLSB - -mlockedSeedCopy :: (KnownNat n, MonadSodium m) => MLockedSeed n -> m (MLockedSeed n) -mlockedSeedCopy = - withMLockedSeedAsMLSB mlsbCopy - -mlockedSeedNew :: (KnownNat n, MonadSodium m) => m (MLockedSeed n) -mlockedSeedNew = - MLockedSeed <$> mlsbNew - -mlockedSeedNewZero :: (KnownNat n, MonadSodium m) => m (MLockedSeed n) -mlockedSeedNewZero = - MLockedSeed <$> mlsbNewZero - -mlockedSeedFinalize :: (MonadSodium m) => MLockedSeed n -> m () -mlockedSeedFinalize = mlsbFinalize . mlockedSeedMLSB - -mlockedSeedUseAsCPtr :: (MonadSodium m) => MLockedSeed n -> (Ptr Word8 -> m b) -> m b -mlockedSeedUseAsCPtr seed = mlsbUseAsCPtr (mlockedSeedMLSB seed) - -mlockedSeedUseAsSizedPtr :: (MonadSodium m) => MLockedSeed n -> (SizedPtr n -> m b) -> m b -mlockedSeedUseAsSizedPtr seed = mlsbUseAsSizedPtr (mlockedSeedMLSB seed) diff --git a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium.hs b/cardano-crypto-class/src/Cardano/Crypto/MonadSodium.hs deleted file mode 100644 index 0085fdb7c..000000000 --- a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium.hs +++ /dev/null @@ -1,64 +0,0 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} - --- We need this so that we can forward the deprecated traceMLockedForeignPtr -{-# OPTIONS_GHC -Wno-deprecations #-} - --- | The Libsodium API generalized to fit arbitrary-ish Monads. --- --- The purpose of this module is to provide a drop-in replacement for the plain --- 'Cardano.Crypto.Libsodium' module, but such that the Monad in which some --- essential actions run can be mocked, rather than forcing it to be 'IO'. --- --- It may also be used to provide Libsodium functionality in monad stacks that --- have IO at the bottom, but decorate certain Libsodium operations with --- additional effects, e.g. logging mlocked memory access. -module Cardano.Crypto.MonadSodium -( - -- * MonadSodium class - MonadSodium (..), - - -- * Re-exported types - MLockedForeignPtr, - MLockedSizedBytes, - - -- * Monadic Eq and Ord - MEq (..), - nequalsM, - (==!), (!=!), - PureMEq (..), - - -- * Memory management - mlockedAllocaSized, - mlockedAllocForeignPtr, - mlockedAllocForeignPtrBytes, - - -- * MLockedSizedBytes operations - mlsbNew, - mlsbZero, - mlsbNewZero, - mlsbCopy, - mlsbFinalize, - mlsbToByteString, - mlsbAsByteString, - mlsbFromByteString, - mlsbFromByteStringCheck, - mlsbUseAsSizedPtr, - mlsbUseAsCPtr, - mlsbCompare, - mlsbEq, - - -- * Hashing - SodiumHashAlgorithm (..), - expandHash, - digestMLockedStorable, - digestMLockedBS, -) -where - -import Cardano.Crypto.MonadSodium.Class -import Cardano.Crypto.MonadSodium.Alloc -import Cardano.Crypto.Libsodium.Hash -import Cardano.Crypto.Libsodium.MLockedBytes -import Cardano.Crypto.MEqOrd diff --git a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Alloc.hs b/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Alloc.hs deleted file mode 100644 index 5eecbafb5..000000000 --- a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Alloc.hs +++ /dev/null @@ -1,77 +0,0 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} - --- We need this so that we can forward the deprecated traceMLockedForeignPtr -{-# OPTIONS_GHC -Wno-deprecations #-} - --- | The Libsodium API generalized to fit arbitrary-ish Monads. --- --- The purpose of this module is to provide a drop-in replacement for the plain --- 'Cardano.Crypto.Libsodium' module, but such that the Monad in which some --- essential actions run can be mocked, rather than forcing it to be 'IO'. --- --- It may also be used to provide Libsodium functionality in monad stacks that --- have IO at the bottom, but decorate certain Libsodium operations with --- additional effects, e.g. logging mlocked memory access. -module Cardano.Crypto.MonadSodium.Alloc -( - MonadSodium (..), - mlockedAlloca, - mlockedAllocaSized, - mlockedAllocForeignPtr, - mlockedAllocForeignPtrBytes, - - -- * Re-exports from plain Libsodium module - NaCl.MLockedForeignPtr, -) -where - -import Cardano.Crypto.MonadSodium.Class -import Control.Monad.Class.MonadThrow (MonadThrow, bracket) - -import qualified Cardano.Crypto.Libsodium.Memory as NaCl - -import Cardano.Foreign (SizedPtr (..)) - -import GHC.TypeLits (KnownNat, natVal) -import Foreign.Storable (Storable (..)) -import Foreign.C.Types (CSize) -import Foreign.Ptr (Ptr) -import Data.Proxy (Proxy (..)) - -mlockedAllocaSized :: forall m n b. (MonadSodium m, MonadThrow m, KnownNat n) => (SizedPtr n -> m b) -> m b -mlockedAllocaSized k = mlockedAlloca size (k . SizedPtr) where - size :: CSize - size = fromInteger (natVal (Proxy @n)) - -mlockedAllocForeignPtrBytes :: (MonadSodium m) => CSize -> CSize -> m (MLockedForeignPtr a) -mlockedAllocForeignPtrBytes size align = do - mlockedMalloc size' - where - size' :: CSize - size' - | m == 0 = size - | otherwise = (q + 1) * align - where - (q,m) = size `quotRem` align - -mlockedAllocForeignPtr :: forall a m . (MonadSodium m, Storable a) => m (MLockedForeignPtr a) -mlockedAllocForeignPtr = - mlockedAllocForeignPtrBytes size align - where - dummy :: a - dummy = undefined - - size :: CSize - size = fromIntegral $ sizeOf dummy - - align :: CSize - align = fromIntegral $ alignment dummy - -mlockedAlloca :: forall a b m. (MonadSodium m, MonadThrow m) => CSize -> (Ptr a -> m b) -> m b -mlockedAlloca size = - bracket alloc free . flip withMLockedForeignPtr - where - alloc = mlockedMalloc size - free = finalizeMLockedForeignPtr diff --git a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Class.hs b/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Class.hs deleted file mode 100644 index b9a0cde56..000000000 --- a/cardano-crypto-class/src/Cardano/Crypto/MonadSodium/Class.hs +++ /dev/null @@ -1,62 +0,0 @@ -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FlexibleInstances #-} - --- We need this so that we can forward the deprecated traceMLockedForeignPtr -{-# OPTIONS_GHC -Wno-deprecations #-} - --- | The Libsodium API generalized to fit arbitrary-ish Monads. --- --- The purpose of this module is to provide a drop-in replacement for the plain --- 'Cardano.Crypto.Libsodium' module, but such that the Monad in which some --- essential actions run can be mocked, rather than forcing it to be 'IO'. --- --- It may also be used to provide Libsodium functionality in monad stacks that --- have IO at the bottom, but decorate certain Libsodium operations with --- additional effects, e.g. logging mlocked memory access. -module Cardano.Crypto.MonadSodium.Class -( - MonadSodium (..), - - -- * Re-exports from plain Libsodium module - NaCl.MLockedForeignPtr, -) -where - -import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..)) - -import qualified Cardano.Crypto.Libsodium.Memory as NaCl -import Control.Monad (void) - -import Cardano.Foreign (c_memset, c_memcpy) - -import Foreign.Ptr (Ptr, castPtr) -import Foreign.Storable (Storable) -import Foreign.C.Types (CSize) - -{-# DEPRECATED traceMLockedForeignPtr "Do not use traceMLockedForeignPtr in production" #-} - --- | Primitive operations on unmanaged mlocked memory. --- These are all implemented in 'IO' underneath, but should morally be in 'ST'. --- There are two use cases for this: --- - Running mlocked-memory operations in a mocking context (e.g. 'IOSim') for --- testing purposes. --- - Running mlocked-memory operations directly on some monad stack with 'IO' --- at the bottom. -class Monad m => MonadSodium m where - withMLockedForeignPtr :: forall a b. MLockedForeignPtr a -> (Ptr a -> m b) -> m b - finalizeMLockedForeignPtr :: forall a. MLockedForeignPtr a -> m () - traceMLockedForeignPtr :: (Storable a, Show a) => MLockedForeignPtr a -> m () - mlockedMalloc :: CSize -> m (MLockedForeignPtr a) - zeroMem :: Ptr a -> CSize -> m () - copyMem :: Ptr a -> Ptr a -> CSize -> m () - -instance MonadSodium IO where - withMLockedForeignPtr = NaCl.withMLockedForeignPtr - finalizeMLockedForeignPtr = NaCl.finalizeMLockedForeignPtr - traceMLockedForeignPtr = NaCl.traceMLockedForeignPtr - mlockedMalloc = NaCl.mlockedMalloc - zeroMem ptr size = void $ c_memset (castPtr ptr) 0 size - copyMem dst src size = void $ c_memcpy (castPtr dst) (castPtr src) size diff --git a/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs b/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs index baa748f2b..283b70f55 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/PinnedSizedBytes.hs @@ -38,7 +38,8 @@ import Data.Kind (Type) import Control.DeepSeq (NFData) import Control.Monad.ST (runST) import Control.Monad.ST.Unsafe (unsafeIOToST) -import Control.Monad.Primitive (PrimMonad, primitive_, touch) +import Control.Monad.Class.MonadST +import Control.Monad.Primitive (primitive_, touch) import Data.Primitive.ByteArray ( ByteArray (..) , MutableByteArray (..) @@ -237,7 +238,7 @@ instance KnownNat n => Storable (PinnedSizedBytes n) where {-# INLINE psbUseAsCPtr #-} psbUseAsCPtr :: forall (n :: Nat) (r :: Type) (m :: Type -> Type) . - (PrimMonad m) => + (MonadST m) => PinnedSizedBytes n -> (Ptr Word8 -> m r) -> m r @@ -260,7 +261,7 @@ psbUseAsCPtr (PSB ba) = runAndTouch ba {-# INLINE psbUseAsCPtrLen #-} psbUseAsCPtrLen :: forall (n :: Nat) (r :: Type) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => PinnedSizedBytes n -> (Ptr Word8 -> CSize -> m r) -> m r @@ -275,20 +276,20 @@ psbUseAsCPtrLen (PSB ba) f = do {-# INLINE psbUseAsSizedPtr #-} psbUseAsSizedPtr :: forall (n :: Nat) (r :: Type) (m :: Type -> Type) . - (PrimMonad m) => + (MonadST m) => PinnedSizedBytes n -> (SizedPtr n -> m r) -> m r -psbUseAsSizedPtr (PSB ba) k = do +psbUseAsSizedPtr (PSB ba) k = withLiftST $ \lift -> do r <- k (SizedPtr $ castPtr $ byteArrayContents ba) - r <$ touch ba + r <$ lift (touch ba) -- | As 'psbCreateResult', but presumes that no useful value is produced: that -- is, the function argument is run only for its side effects. {-# INLINE psbCreate #-} psbCreate :: forall (n :: Nat) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (Ptr Word8 -> m ()) -> m (PinnedSizedBytes n) psbCreate f = fst <$> psbCreateResult f @@ -298,7 +299,7 @@ psbCreate f = fst <$> psbCreateResult f {-# INLINE psbCreateLen #-} psbCreateLen :: forall (n :: Nat) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (Ptr Word8 -> CSize -> m ()) -> m (PinnedSizedBytes n) psbCreateLen f = fst <$> psbCreateResultLen f @@ -321,7 +322,7 @@ psbCreateLen f = fst <$> psbCreateResultLen f {-# INLINE psbCreateResult #-} psbCreateResult :: forall (n :: Nat) (r :: Type) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (Ptr Word8 -> m r) -> m (PinnedSizedBytes n, r) psbCreateResult f = psbCreateResultLen (\p _ -> f p) @@ -341,14 +342,14 @@ psbCreateResult f = psbCreateResultLen (\p _ -> f p) {-# INLINE psbCreateResultLen #-} psbCreateResultLen :: forall (n :: Nat) (r :: Type) (m :: Type -> Type). - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (Ptr Word8 -> CSize -> m r) -> m (PinnedSizedBytes n, r) -psbCreateResultLen f = do +psbCreateResultLen f = withLiftST $ \lift -> do let len :: Int = fromIntegral . natVal $ Proxy @n - mba <- newPinnedByteArray len + mba <- lift (newPinnedByteArray len) res <- f (mutableByteArrayContents mba) (fromIntegral len) - arr <- unsafeFreezeByteArray mba + arr <- lift (unsafeFreezeByteArray mba) pure (PSB arr, res) -- | As 'psbCreateSizedResult', but presumes that no useful value is produced: @@ -356,7 +357,7 @@ psbCreateResultLen f = do {-# INLINE psbCreateSized #-} psbCreateSized :: forall (n :: Nat) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (SizedPtr n -> m ()) -> m (PinnedSizedBytes n) psbCreateSized k = psbCreate (k . SizedPtr . castPtr) @@ -367,7 +368,7 @@ psbCreateSized k = psbCreate (k . SizedPtr . castPtr) {-# INLINE psbCreateSizedResult #-} psbCreateSizedResult :: forall (n :: Nat) (r :: Type) (m :: Type -> Type) . - (KnownNat n, PrimMonad m) => + (KnownNat n, MonadST m) => (SizedPtr n -> m r) -> m (PinnedSizedBytes n, r) psbCreateSizedResult f = psbCreateResult (f . SizedPtr . castPtr) @@ -405,10 +406,10 @@ die fun problem = error $ "PinnedSizedBytes." ++ fun ++ ": " ++ problem {-# INLINE runAndTouch #-} runAndTouch :: forall (a :: Type) (m :: Type -> Type) . - (PrimMonad m) => + (MonadST m) => ByteArray -> (Ptr Word8 -> m a) -> m a -runAndTouch ba f = do +runAndTouch ba f = withLiftST $ \lift -> do r <- f (byteArrayContents ba) - r <$ touch ba + r <$ lift (touch ba) diff --git a/cardano-crypto-tests/cardano-crypto-tests.cabal b/cardano-crypto-tests/cardano-crypto-tests.cabal index 2be302d9d..040043a14 100644 --- a/cardano-crypto-tests/cardano-crypto-tests.cabal +++ b/cardano-crypto-tests/cardano-crypto-tests.cabal @@ -54,7 +54,6 @@ library Test.Crypto.VRF Test.Crypto.Regressions Test.Crypto.Instances - Cardano.Crypto.KES.ForgetMock Bench.Crypto.DSIGN Bench.Crypto.VRF Bench.Crypto.KES @@ -78,7 +77,6 @@ library , mtl , nothunks , pretty-show - , random , QuickCheck , quickcheck-instances , tasty diff --git a/cardano-crypto-tests/src/Bench/Crypto/KES.hs b/cardano-crypto-tests/src/Bench/Crypto/KES.hs index a5126dc1d..76ae4db53 100644 --- a/cardano-crypto-tests/src/Bench/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Bench/Crypto/KES.hs @@ -26,7 +26,7 @@ import Criterion import qualified Data.ByteString as BS (ByteString) import Data.Either (fromRight) import Cardano.Crypto.Libsodium as NaCl -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed import System.IO.Unsafe (unsafePerformIO) import GHC.TypeLits (KnownNat) import Data.Kind (Type) @@ -51,7 +51,7 @@ benchmarks = bgroup "KES" {-# NOINLINE benchKES #-} benchKES :: forall (proxy :: forall k. k -> Type) v - . ( KESSignAlgorithm IO v + . ( KESSignAlgorithm v , ContextKES v ~ () , Signable v BS.ByteString , NFData (SignKeyKES v) @@ -63,21 +63,21 @@ benchKES :: forall (proxy :: forall k. k -> Type) v benchKES _ lbl = bgroup lbl [ bench "genKey" $ - nfIO $ genKeyKES @IO @v testSeedML >>= forgetSignKeyKES @IO @v + nfIO $ genKeyKES @v testSeedML >>= forgetSignKeyKES @v , bench "signKES" $ nfIO $ - (\sk -> do { sig <- signKES @IO @v () 0 typicalMsg sk; forgetSignKeyKES sk; return sig }) - =<< (genKeyKES @IO @v testSeedML) + (\sk -> do { sig <- signKES @v() 0 typicalMsg sk; forgetSignKeyKES sk; return sig }) + =<< genKeyKES @v testSeedML , bench "verifyKES" $ nfIO $ do - signKey <- genKeyKES @IO @v testSeedML - sig <- signKES @IO @v () 0 typicalMsg signKey + signKey <- genKeyKES @v testSeedML + sig <- signKES @v () 0 typicalMsg signKey verKey <- deriveVerKeyKES signKey forgetSignKeyKES signKey return . fromRight $ verifyKES @v () verKey 0 typicalMsg sig , bench "updateKES" $ nfIO $ do - signKey <- genKeyKES @IO @v testSeedML + signKey <- genKeyKES @v testSeedML sk' <- fromJust <$> updateKES () signKey 0 forgetSignKeyKES signKey return sk' diff --git a/cardano-crypto-tests/src/Cardano/Crypto/KES/ForgetMock.hs b/cardano-crypto-tests/src/Cardano/Crypto/KES/ForgetMock.hs deleted file mode 100644 index 9a0536a76..000000000 --- a/cardano-crypto-tests/src/Cardano/Crypto/KES/ForgetMock.hs +++ /dev/null @@ -1,169 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiParamTypeClasses #-} - --- | Mock key evolving signatures. -module Cardano.Crypto.KES.ForgetMock - ( ForgetMockKES - , VerKeyKES (..) - , SignKeyKES (..) - , SigKES (..) - , ForgetMockEvent (..) - , isGEN - , isUPD - , isDEL - ) -where - -import Data.Proxy (Proxy(..)) -import GHC.Generics (Generic) - -import Cardano.Crypto.KES.Class -import NoThunks.Class (NoThunks (..), allNoThunks) -import System.Random (randomRIO) -import Control.Tracer -import Test.Crypto.AllocLog -import Control.Monad.IO.Class (MonadIO) -import Control.Monad.Reader (ask) -import Control.Monad ((<$!>)) - --- | A wrapper for a KES implementation that adds logging functionality, for --- the purpose of verifying that invocations of 'genKeyKES' and --- 'forgetSignKeyKES' pair up properly in a given host application. --- --- The wrapped KES behaves exactly like its unwrapped payload, except that --- invocations of 'genKeyKES', 'updateKES' and 'forgetSignKeyKES' are logged --- as 'GenericEvent' 'ForgetMockEvent' values. (We use 'GenericEvent' in order --- to use the generic 'MonadSodium' instance of 'LogT'; otherwise we would --- have to provide a boilerplate instance here). -data ForgetMockKES k - -data ForgetMockEvent - = GEN Word - | UPD Word Word - | NOUPD - | DEL Word - deriving (Ord, Eq, Show) - -isGEN :: ForgetMockEvent -> Bool -isGEN GEN {} = True -isGEN _ = False - -isUPD :: ForgetMockEvent -> Bool -isUPD UPD {} = True -isUPD _ = False - -isDEL :: ForgetMockEvent -> Bool -isDEL DEL {} = True -isDEL _ = False - -instance - ( KESAlgorithm k - ) - => KESAlgorithm (ForgetMockKES k) where - type SeedSizeKES (ForgetMockKES k) = SeedSizeKES k - type Signable (ForgetMockKES k) = Signable k - - newtype VerKeyKES (ForgetMockKES k) = VerKeyForgetMockKES (VerKeyKES k) - deriving (Generic) - newtype SigKES (ForgetMockKES k) = SigForgetMockKES (SigKES k) - deriving (Generic) - - type ContextKES (ForgetMockKES k) = ContextKES k - - algorithmNameKES _ = algorithmNameKES (Proxy @k) - - verifyKES ctx (VerKeyForgetMockKES vk) p msg (SigForgetMockKES sig) = - verifyKES ctx vk p msg sig - - totalPeriodsKES _ = totalPeriodsKES (Proxy @k) - - type SizeVerKeyKES (ForgetMockKES k) = SizeVerKeyKES k - type SizeSignKeyKES (ForgetMockKES k) = SizeSignKeyKES k - type SizeSigKES (ForgetMockKES k) = SizeSigKES k - - rawSerialiseVerKeyKES (VerKeyForgetMockKES k) = rawSerialiseVerKeyKES k - rawSerialiseSigKES (SigForgetMockKES k) = rawSerialiseSigKES k - - rawDeserialiseVerKeyKES = fmap VerKeyForgetMockKES . rawDeserialiseVerKeyKES - rawDeserialiseSigKES = fmap SigForgetMockKES . rawDeserialiseSigKES - - -instance - ( KESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) m) k - , MonadIO m - ) - => KESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) m) (ForgetMockKES k) where - data SignKeyKES (ForgetMockKES k) = SignKeyForgetMockKES !Word !(SignKeyKES k) - - genKeyKES seed = do - sk <- genKeyKES seed - nonce <- randomRIO (10000000, 99999999) - tracer <- ask - traceWith tracer (GenericEvent $ GEN nonce) - return $! SignKeyForgetMockKES nonce sk - - forgetSignKeyKES (SignKeyForgetMockKES nonce sk) = do - tracer <- ask - traceWith tracer (GenericEvent $ DEL nonce) - forgetSignKeyKES sk - - deriveVerKeyKES (SignKeyForgetMockKES _ k) = - VerKeyForgetMockKES <$!> deriveVerKeyKES k - - signKES ctx p msg (SignKeyForgetMockKES _ sk) = - SigForgetMockKES <$!> signKES ctx p msg sk - - updateKES ctx (SignKeyForgetMockKES nonce sk) p = do - tracer <- ask - nonce' <- randomRIO (10000000, 99999999) - updateKES ctx sk p >>= \case - Just sk' -> do - traceWith tracer (GenericEvent $ UPD nonce nonce') - return $! Just $! SignKeyForgetMockKES nonce' sk' - Nothing -> do - traceWith tracer (GenericEvent NOUPD) - return Nothing - -instance - ( UnsoundKESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) m) k - , MonadIO m - ) - => UnsoundKESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) m) (ForgetMockKES k) where - - rawSerialiseSignKeyKES (SignKeyForgetMockKES _ k) = rawSerialiseSignKeyKES k - - rawDeserialiseSignKeyKES bs = do - msk <- rawDeserialiseSignKeyKES bs - nonce :: Word <- randomRIO (10000000, 99999999) - return $ fmap (SignKeyForgetMockKES nonce) msk - - -deriving instance Show (VerKeyKES k) => Show (VerKeyKES (ForgetMockKES k)) -deriving instance Eq (VerKeyKES k) => Eq (VerKeyKES (ForgetMockKES k)) -deriving instance Ord (VerKeyKES k) => Ord (VerKeyKES (ForgetMockKES k)) -deriving instance NoThunks (VerKeyKES k) => NoThunks (VerKeyKES (ForgetMockKES k)) - -deriving instance Eq (SignKeyKES k) => Eq (SignKeyKES (ForgetMockKES k)) - -instance NoThunks (SignKeyKES k) => NoThunks (SignKeyKES (ForgetMockKES k)) where - showTypeOf _ = "SignKeyKES (ForgetMockKES k)" - wNoThunks ctx (SignKeyForgetMockKES t k) = - allNoThunks - [ noThunks ctx t - , noThunks ctx k - ] - -deriving instance Show (SigKES k) => Show (SigKES (ForgetMockKES k)) -deriving instance Eq (SigKES k) => Eq (SigKES (ForgetMockKES k)) -deriving instance Ord (SigKES k) => Ord (SigKES (ForgetMockKES k)) -deriving instance NoThunks (SigKES k) => NoThunks (SigKES (ForgetMockKES k)) diff --git a/cardano-crypto-tests/src/Test/Crypto/AllocLog.hs b/cardano-crypto-tests/src/Test/Crypto/AllocLog.hs index 7c2c755df..8e9afd478 100644 --- a/cardano-crypto-tests/src/Test/Crypto/AllocLog.hs +++ b/cardano-crypto-tests/src/Test/Crypto/AllocLog.hs @@ -1,23 +1,17 @@ -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# OPTIONS_GHC -Wno-deprecations #-} module Test.Crypto.AllocLog where -import Cardano.Crypto.MonadSodium -import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..)) import Control.Tracer -import Control.Monad.Reader -import Foreign.Ptr -import Control.Monad.Class.MonadThrow -import Control.Monad.Class.MonadST -import Control.Monad.ST.Unsafe (unsafeIOToST) import Data.Typeable -import Data.Coerce (coerce) -import Foreign.Concurrent (addForeignPtrFinalizer) -import Test.Crypto.RunIO +import Foreign.Ptr +import Foreign.Concurrent + +import Cardano.Crypto.Libsodium (withMLockedForeignPtr) +import Cardano.Crypto.Libsodium.Memory (MLockedAllocator(..)) +import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..)) -- | Allocation log event. These are emitted automatically whenever mlocked -- memory is allocated through the 'mlockedAllocForeignPtr' primitive, or @@ -28,74 +22,17 @@ data AllocEvent = AllocEv !WordPtr | FreeEv !WordPtr | MarkerEv !String - deriving (Eq, Show) - -newtype LogT event m a = LogT { unLogT :: ReaderT (Tracer (LogT event m) event) m a } - deriving (Functor, Applicative, Monad, MonadThrow, MonadST, Typeable, MonadIO) - -type AllocLogT = LogT AllocEvent - -instance Monad m => MonadReader (Tracer (LogT event m) event) (LogT event m) where - ask = LogT ask - local f (LogT action) = LogT (local f action) - -instance MonadTrans (LogT event) where - lift action = LogT (lift action) - -runLogT :: Tracer (LogT event m) event -> LogT event m a -> m a -runLogT tracer action = runReaderT (unLogT action) tracer - -runAllocLogT :: Tracer (LogT AllocEvent m) AllocEvent -> LogT AllocEvent m a -> m a -runAllocLogT = runLogT - -pushLogEvent :: Monad m => event -> LogT event m () -pushLogEvent event = do - tracer <- ask - traceWith tracer event - -pushAllocLogEvent :: Monad m => AllocEvent -> LogT AllocEvent m () -pushAllocLogEvent = pushLogEvent - --- | Automatically log all mlocked allocation events (allocate and free) via --- 'mlockedAlloca', 'mlockedMalloc', and associated finalizers. -instance (MonadIO m, MonadThrow m, MonadSodium m, MonadST m, RunIO m) - => MonadSodium (LogT AllocEvent m) where - withMLockedForeignPtr fptr action = LogT $ do - tracer <- ask - lift $ withMLockedForeignPtr fptr (\ptr -> (runReaderT . unLogT) (action ptr) tracer) - - finalizeMLockedForeignPtr = lift . finalizeMLockedForeignPtr - - traceMLockedForeignPtr = lift . traceMLockedForeignPtr - - mlockedMalloc size = do - fptr <- lift (mlockedMalloc size) - addr <- withMLockedForeignPtr fptr (return . ptrToWordPtr) - pushAllocLogEvent (AllocEv addr) - tracer :: Tracer (LogT event m) event <- ask - withLiftST $ \liftST -> liftST . unsafeIOToST $ - addForeignPtrFinalizer - (coerce fptr) - (io . runLogT tracer . pushAllocLogEvent $ FreeEv addr) - return fptr - - zeroMem addr size = lift $ zeroMem addr size - copyMem dst src size = lift $ copyMem dst src size - --- | Newtype wrapper over an arbitrary event; we use this to write the generic --- 'MonadSodium' instance below while avoiding overlapping instances. -newtype GenericEvent e = GenericEvent { concreteEvent :: e } - --- | Generic instance, log nothing automatically. Log entries can be triggered --- manually using 'pushLogEvent'. -instance MonadSodium m => MonadSodium (LogT (GenericEvent e) m) where - withMLockedForeignPtr fptr (action) = LogT $ do - tracer <- ask - lift $ withMLockedForeignPtr fptr (\ptr -> (runReaderT . unLogT) (action ptr) tracer) - - finalizeMLockedForeignPtr = lift . finalizeMLockedForeignPtr - traceMLockedForeignPtr = lift . traceMLockedForeignPtr - mlockedMalloc size = lift (mlockedMalloc size) - - zeroMem addr size = lift $ zeroMem addr size - copyMem dst src size = lift $ copyMem dst src size + deriving (Eq, Show, Typeable) + +mkLoggingAllocator :: + Tracer IO AllocEvent -> MLockedAllocator IO -> MLockedAllocator IO +mkLoggingAllocator tracer ioAllocator = + MLockedAllocator + { mlAllocate = + \size -> do + sfptr@(SFP fptr) <- mlAllocate ioAllocator size + addr <- withMLockedForeignPtr sfptr (return . ptrToWordPtr) + traceWith tracer (AllocEv addr) + addForeignPtrFinalizer fptr (traceWith tracer (FreeEv addr)) + return sfptr + } diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index 3de3539c9..fcb969906 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -32,8 +32,8 @@ import Test.Tasty (TestTree, testGroup, adjustOption) import Test.Tasty.QuickCheck (testProperty, QuickCheckTests) import qualified Data.ByteString as BS -import qualified Cardano.Crypto.Libsodium as NaCl -import Cardano.Crypto.MonadSodium (MEq (..), (==!)) +import Cardano.Crypto.Libsodium +import Cardano.Crypto.EqST (EqST (..), (==!)) import Text.Show.Pretty (ppShow) @@ -93,9 +93,9 @@ import Cardano.Crypto.DSIGN ( rawSerialiseSigDSIGNM, rawDeserialiseSigDSIGNM), DSIGNMAlgorithm (), - UnsoundDSIGNMAlgorithm ( - rawSerialiseSignKeyDSIGNM, - rawDeserialiseSignKeyDSIGNM), + UnsoundDSIGNMAlgorithm, + rawSerialiseSignKeyDSIGNM, + rawDeserialiseSignKeyDSIGNM, sizeVerKeyDSIGNM, sizeSignKeyDSIGNM, sizeSigDSIGNM, @@ -134,7 +134,7 @@ import Test.Crypto.Util ( withLock, ) import Test.Crypto.Instances (withMLockedSeedFromPSB) -import Cardano.Crypto.MLockedSeed +import Cardano.Crypto.Libsodium.MLockedSeed #ifdef SECP256K1_ENABLED import Cardano.Crypto.DSIGN ( @@ -363,7 +363,7 @@ testDSIGNAlgorithm genSig genMsg name = adjustOption testEnough . testGroup name testDSIGNMAlgorithm :: forall v. ( -- change back to DSIGNMAlgorithm when unsound API is phased out - UnsoundDSIGNMAlgorithm IO v + UnsoundDSIGNMAlgorithm v , ToCBOR (VerKeyDSIGNM v) , FromCBOR (VerKeyDSIGNM v) -- DSIGNM cannot satisfy To/FromCBOR (not even with @@ -372,7 +372,7 @@ testDSIGNMAlgorithm -- test direct encoding/decoding for 'SignKeyDSIGNM'. -- , ToCBOR (SignKeyDSIGNM v) -- , FromCBOR (SignKeyDSIGNM v) - , MEq IO (SignKeyDSIGNM v) -- only monadic MEq for signing keys + , EqST (SignKeyDSIGNM v) -- only monadic EqST for signing keys , ToCBOR (SigDSIGNM v) , FromCBOR (SigDSIGNM v) , ContextDSIGNM v ~ () @@ -406,15 +406,15 @@ testDSIGNMAlgorithm lock _ n = [ testProperty "VerKey" $ ioPropertyWithSK @v lock $ \sk -> do vk <- deriveVerKeyDSIGNM sk - return $ (fromIntegral . BS.length . rawSerialiseVerKeyDSIGNM $ vk) === (sizeVerKeyDSIGNM (Proxy @v)) + return $ (fromIntegral . BS.length . rawSerialiseVerKeyDSIGNM $ vk) === sizeVerKeyDSIGNM (Proxy @v) , testProperty "SignKey" $ ioPropertyWithSK @v lock $ \sk -> do serialized <- rawSerialiseSignKeyDSIGNM sk - evaluate ((fromIntegral . BS.length $ serialized) == (sizeSignKeyDSIGNM (Proxy @v))) + evaluate ((fromIntegral . BS.length $ serialized) == sizeSignKeyDSIGNM (Proxy @v)) , testProperty "Sig" $ \(msg :: Message) -> ioPropertyWithSK @v lock $ \sk -> do sig :: SigDSIGNM v <- signDSIGNM () msg sk - return $ (fromIntegral . BS.length . rawSerialiseSigDSIGNM $ sig) === (sizeSigDSIGNM (Proxy @v)) + return $ (fromIntegral . BS.length . rawSerialiseSigDSIGNM $ sig) === sizeSigDSIGNM (Proxy @v) ] , testGroup "direct CBOR" @@ -432,37 +432,37 @@ testDSIGNMAlgorithm lock _ n = , testGroup "To/FromCBOR class" [ testProperty "VerKey" $ - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do vk :: VerKeyDSIGNM v <- deriveVerKeyDSIGNM sk return $ prop_cbor vk -- No To/FromCBOR for 'SignKeyDSIGNM', see above. , testProperty "Sig" $ \(msg :: Message) -> - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do sig :: SigDSIGNM v <- signDSIGNM () msg sk return $ prop_cbor sig ] , testGroup "ToCBOR size" [ testProperty "VerKey" $ - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do vk :: VerKeyDSIGNM v <- deriveVerKeyDSIGNM sk return $ prop_cbor_size vk -- No To/FromCBOR for 'SignKeyDSIGNM', see above. , testProperty "Sig" $ \(msg :: Message) -> - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do sig :: SigDSIGNM v <- signDSIGNM () msg sk return $ prop_cbor_size sig ] , testGroup "direct matches class" [ testProperty "VerKey" $ - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do vk :: VerKeyDSIGNM v <- deriveVerKeyDSIGNM sk return $ prop_cbor_direct_vs_class encodeVerKeyDSIGNM vk -- No CBOR testing for SignKey: sign keys are stored in MLocked memory -- and require IO for access. , testProperty "Sig" $ \(msg :: Message) -> - ioPropertyWithSK lock $ \sk -> do + ioPropertyWithSK @v lock $ \sk -> do sig :: SigDSIGNM v <- signDSIGNM () msg sk return $ prop_cbor_direct_vs_class encodeSigDSIGNM sig ] @@ -500,7 +500,7 @@ testDSIGNMAlgorithm lock _ n = -- timely forgetting. Special care must be taken to not leak the key outside of -- the wrapped action (be particularly mindful of thunks and unsafe key access -- here). -withSK :: (DSIGNMAlgorithm IO v) => PinnedSizedBytes (SeedSizeDSIGNM v) -> (SignKeyDSIGNM v -> IO b) -> IO b +withSK :: (DSIGNMAlgorithm v) => PinnedSizedBytes (SeedSizeDSIGNM v) -> (SignKeyDSIGNM v -> IO b) -> IO b withSK seedPSB action = withMLockedSeedFromPSB seedPSB $ \seed -> bracket @@ -515,7 +515,7 @@ withSK seedPSB action = -- memory. Special care must be taken to not leak the key outside of the -- wrapped action (be particularly mindful of thunks and unsafe key access -- here). -ioPropertyWithSK :: forall v a. (Testable a, DSIGNMAlgorithm IO v) +ioPropertyWithSK :: forall v a. (Testable a, DSIGNMAlgorithm v) => Lock -> (SignKeyDSIGNM v -> IO a) -> PinnedSizedBytes (SeedSizeDSIGNM v) @@ -525,7 +525,7 @@ ioPropertyWithSK lock action seedPSB = prop_key_overwritten_after_forget :: forall v. - (DSIGNMAlgorithm IO v + (DSIGNMAlgorithm v ) => Proxy v -> PinnedSizedBytes (SeedSizeDSIGNM v) @@ -536,20 +536,20 @@ prop_key_overwritten_after_forget p seedPSB = mlockedSeedFinalize seed seedBefore <- getSeedDSIGNM p sk - bsBefore <- NaCl.mlsbToByteString . mlockedSeedMLSB $ seedBefore + bsBefore <- mlsbToByteString . mlockedSeedMLSB $ seedBefore mlockedSeedFinalize seedBefore forgetSignKeyDSIGNM sk seedAfter <- getSeedDSIGNM p sk - bsAfter <- NaCl.mlsbToByteString . mlockedSeedMLSB $ seedAfter + bsAfter <- mlsbToByteString . mlockedSeedMLSB $ seedAfter mlockedSeedFinalize seedAfter return (bsBefore =/= bsAfter) prop_dsignm_seed_roundtrip :: forall v. - ( DSIGNMAlgorithm IO v + ( DSIGNMAlgorithm v ) => Proxy v -> PinnedSizedBytes (SeedSizeDSIGNM v) @@ -557,8 +557,8 @@ prop_dsignm_seed_roundtrip prop_dsignm_seed_roundtrip p seedPSB = ioProperty . withMLockedSeedFromPSB seedPSB $ \seed -> do sk <- genKeyDSIGNM seed seed' <- getSeedDSIGNM p sk - bs <- NaCl.mlsbToByteString . mlockedSeedMLSB $ seed - bs' <- NaCl.mlsbToByteString . mlockedSeedMLSB $ seed' + bs <- mlsbToByteString . mlockedSeedMLSB $ seed + bs' <- mlsbToByteString . mlockedSeedMLSB $ seed' forgetSignKeyDSIGNM sk mlockedSeedFinalize seed' return (bs === bs') @@ -594,7 +594,7 @@ prop_dsign_verify_wrong_key (msg, sk, sk') = in verifyDSIGN () vk' msg signed =/= Right () prop_dsignm_verify_pos - :: forall v. (DSIGNMAlgorithm IO v, ContextDSIGNM v ~ (), SignableM v Message) + :: forall v. (DSIGNMAlgorithm v, ContextDSIGNM v ~ (), SignableM v Message) => Lock -> Proxy v -> Message @@ -611,7 +611,7 @@ prop_dsignm_verify_pos lock _ msg = -- different signing key, then the verification fails. -- prop_dsignm_verify_neg_key - :: forall v. (DSIGNMAlgorithm IO v, ContextDSIGNM v ~ (), SignableM v Message) + :: forall v. (DSIGNMAlgorithm v, ContextDSIGNM v ~ (), SignableM v Message) => Lock -> Proxy v -> Message @@ -681,7 +681,7 @@ testEcdsaWithHashAlgorithm _ name = adjustOption defaultTestEnough . testGroup n #endif prop_dsignm_verify_neg_msg - :: forall v. (DSIGNMAlgorithm IO v, ContextDSIGNM v ~ (), SignableM v Message) + :: forall v. (DSIGNMAlgorithm v, ContextDSIGNM v ~ (), SignableM v Message) => Lock -> Proxy v -> Message diff --git a/cardano-crypto-tests/src/Test/Crypto/Instances.hs b/cardano-crypto-tests/src/Test/Crypto/Instances.hs index f295d99ee..7119b4c0e 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Instances.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Instances.hs @@ -12,13 +12,9 @@ import Data.Proxy (Proxy (Proxy)) import GHC.TypeLits (KnownNat, natVal) import Test.QuickCheck (Arbitrary (..)) import qualified Test.QuickCheck.Gen as Gen -import Cardano.Crypto.MonadSodium -import Cardano.Crypto.MLockedSeed -import Cardano.Crypto.PinnedSizedBytes ( - PinnedSizedBytes, - psbFromByteStringCheck, - psbToByteString, - ) +import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.PinnedSizedBytes import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadST @@ -37,19 +33,19 @@ import Control.Monad.Class.MonadST -- size :: Int -- size = fromInteger (natVal (Proxy :: Proxy n)) -mlsbFromPSB :: (MonadSodium m, MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSizedBytes n) +mlsbFromPSB :: (MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSizedBytes n) mlsbFromPSB = mlsbFromByteString . psbToByteString -withMLSBFromPSB :: (MonadSodium m, MonadST m, MonadThrow m, KnownNat n) => PinnedSizedBytes n -> (MLockedSizedBytes n -> m a) -> m a +withMLSBFromPSB :: (MonadST m, MonadThrow m, KnownNat n) => PinnedSizedBytes n -> (MLockedSizedBytes n -> m a) -> m a withMLSBFromPSB psb = bracket (mlsbFromPSB psb) mlsbFinalize -mlockedSeedFromPSB :: (MonadSodium m, MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSeed n) +mlockedSeedFromPSB :: (MonadST m, KnownNat n) => PinnedSizedBytes n -> m (MLockedSeed n) mlockedSeedFromPSB = fmap MLockedSeed . mlsbFromPSB -withMLockedSeedFromPSB :: (MonadSodium m, MonadST m, MonadThrow m, KnownNat n) => PinnedSizedBytes n -> (MLockedSeed n -> m a) -> m a +withMLockedSeedFromPSB :: (MonadST m, MonadThrow m, KnownNat n) => PinnedSizedBytes n -> (MLockedSeed n -> m a) -> m a withMLockedSeedFromPSB psb = bracket (mlockedSeedFromPSB psb) diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index efad9bdca..4212f3188 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} @@ -30,24 +31,21 @@ import Data.Set (Set) import qualified Data.Set as Set import Foreign.Ptr (WordPtr) import Data.IORef -import Data.Foldable (traverse_) import GHC.TypeNats (KnownNat) import Control.Tracer -import Control.Monad (void) -import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Class.MonadThrow -import Control.Monad.Class.MonadST +import Control.Monad.IO.Class (liftIO) +import Control.Monad (void) import Cardano.Crypto.DSIGN hiding (Signable) import Cardano.Crypto.Hash import Cardano.Crypto.KES -import Cardano.Crypto.KES.ForgetMock import Cardano.Crypto.Util (SignableRepresentation(..)) -import Cardano.Crypto.MLockedSeed -import qualified Cardano.Crypto.Libsodium as NaCl -import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes) -import Cardano.Crypto.MonadSodium +import Cardano.Crypto.Libsodium +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.PinnedSizedBytes +import Cardano.Crypto.EqST import Test.QuickCheck import Test.Tasty (TestTree, testGroup, adjustOption) @@ -71,11 +69,11 @@ import Test.Crypto.Util ( Lock, withLock, ) -import Test.Crypto.RunIO (RunIO (..)) import Test.Crypto.Instances (withMLockedSeedFromPSB) import Test.Crypto.AllocLog {- HLINT ignore "Reduce duplication" -} +{- HLINT ignore "Use head" -} -- -- The list of all tests @@ -83,18 +81,18 @@ import Test.Crypto.AllocLog tests :: Lock -> TestTree tests lock = testGroup "Crypto.KES" - [ testKESAlloc (Proxy :: Proxy (SingleKES Ed25519DSIGNM)) "SingleKES" - , testKESAlloc (Proxy :: Proxy (Sum1KES Ed25519DSIGNM Blake2b_256)) "Sum1KES" - , testKESAlloc (Proxy :: Proxy (Sum2KES Ed25519DSIGNM Blake2b_256)) "Sum2KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (MockKES 7)) "MockKES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (SimpleKES Ed25519DSIGNM 7)) "SimpleKES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (SingleKES Ed25519DSIGNM)) "SingleKES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (Sum1KES Ed25519DSIGNM Blake2b_256)) "Sum1KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (Sum2KES Ed25519DSIGNM Blake2b_256)) "Sum2KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (Sum5KES Ed25519DSIGNM Blake2b_256)) "Sum5KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (CompactSum1KES Ed25519DSIGNM Blake2b_256)) "CompactSum1KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (CompactSum2KES Ed25519DSIGNM Blake2b_256)) "CompactSum2KES" - , testKESAlgorithm lock (Proxy :: Proxy IO) (Proxy :: Proxy (CompactSum5KES Ed25519DSIGNM Blake2b_256)) "CompactSum5KES" + [ testKESAlloc (Proxy @(SingleKES Ed25519DSIGNM)) "SingleKES" + , testKESAlloc (Proxy @(Sum1KES Ed25519DSIGNM Blake2b_256)) "Sum1KES" + , testKESAlloc (Proxy @(Sum2KES Ed25519DSIGNM Blake2b_256)) "Sum2KES" + , testKESAlgorithm @(MockKES 7) lock "MockKES" + , testKESAlgorithm @(SimpleKES Ed25519DSIGNM 7) lock "SimpleKES" + , testKESAlgorithm @(SingleKES Ed25519DSIGNM) lock "SingleKES" + , testKESAlgorithm @(Sum1KES Ed25519DSIGNM Blake2b_256) lock "Sum1KES" + , testKESAlgorithm @(Sum2KES Ed25519DSIGNM Blake2b_256) lock "Sum2KES" + , testKESAlgorithm @(Sum5KES Ed25519DSIGNM Blake2b_256) lock "Sum5KES" + , testKESAlgorithm @(CompactSum1KES Ed25519DSIGNM Blake2b_256) lock "CompactSum1KES" + , testKESAlgorithm @(CompactSum2KES Ed25519DSIGNM Blake2b_256) lock "CompactSum2KES" + , testKESAlgorithm @(CompactSum5KES Ed25519DSIGNM Blake2b_256) lock "CompactSum5KES" ] -- We normally ensure that we avoid naively comparing signing keys by not @@ -103,7 +101,7 @@ tests lock = instance Show (SignKeyKES (SingleKES Ed25519DSIGNM)) where show (SignKeySingleKES (SignKeyEd25519DSIGNM mlsb)) = - let bytes = NaCl.mlsbAsByteString mlsb + let bytes = mlsbAsByteString mlsb hexstr = hexBS bytes in "SignKeySingleKES (SignKeyEd25519DSIGNM " ++ hexstr ++ ")" @@ -112,109 +110,50 @@ instance Show (SignKeyKES (SumKES h d)) where instance Show (SignKeyKES (CompactSingleKES Ed25519DSIGNM)) where show (SignKeyCompactSingleKES (SignKeyEd25519DSIGNM mlsb)) = - let bytes = NaCl.mlsbAsByteString mlsb + let bytes = mlsbAsByteString mlsb hexstr = hexBS bytes in "SignKeyCompactSingleKES (SignKeyEd25519DSIGNM " ++ hexstr ++ ")" instance Show (SignKeyKES (CompactSumKES h d)) where show _ = "" -deriving via (PureMEq (SignKeyKES (MockKES t))) instance Applicative m => MEq m (SignKeyKES (MockKES t)) +deriving via (PureEqST (SignKeyKES (MockKES t))) instance EqST (SignKeyKES (MockKES t)) -deriving newtype instance (MEq m (SignKeyDSIGNM d)) => MEq m (SignKeyKES (SingleKES d)) +deriving newtype instance (EqST (SignKeyDSIGNM d)) => EqST (SignKeyKES (SingleKES d)) -instance ( MonadSodium m - , MonadST m - , MEq m (SignKeyKES d) +instance ( EqST (SignKeyKES d) , Eq (VerKeyKES d) , KnownNat (SeedSizeKES d) - ) => MEq m (SignKeyKES (SumKES h d)) where + ) => EqST (SignKeyKES (SumKES h d)) where equalsM (SignKeySumKES s r v1 v2) (SignKeySumKES s' r' v1' v2') = - (s, r, PureMEq v1, PureMEq v2) ==! (s', r', PureMEq v1', PureMEq v2') + (s, r, PureEqST v1, PureEqST v2) ==! (s', r', PureEqST v1', PureEqST v2') -deriving newtype instance (MEq m (SignKeyDSIGNM d)) => MEq m (SignKeyKES (CompactSingleKES d)) +deriving newtype instance (EqST (SignKeyDSIGNM d)) => EqST (SignKeyKES (CompactSingleKES d)) -instance ( MonadSodium m - , MonadST m - , MEq m (SignKeyKES d) +instance ( EqST (SignKeyKES d) , Eq (VerKeyKES d) , KnownNat (SeedSizeKES d) - ) => MEq m (SignKeyKES (CompactSumKES h d)) where + ) => EqST (SignKeyKES (CompactSumKES h d)) where equalsM (SignKeyCompactSumKES s r v1 v2) (SignKeyCompactSumKES s' r' v1' v2') = - (s, r, PureMEq v1, PureMEq v2) ==! (s', r', PureMEq v1', PureMEq v2') + (s, r, PureEqST v1, PureEqST v2) ==! (s', r', PureEqST v1', PureEqST v2') testKESAlloc :: forall v. - ( (forall m. (MonadSodium m, MonadThrow m, MonadST m) => KESSignAlgorithm m v) - , ContextKES v ~ () + ( KESSignAlgorithm v ) => Proxy v -> String -> TestTree testKESAlloc _p n = testGroup n - [ testGroup "Forget mock" - [ testCase "genKey" $ testForgetGenKeyKES _p - , testCase "updateKey" $ testForgetUpdateKeyKES _p - ] - , testGroup "Low-level mlocked allocations" + [ testGroup "Low-level mlocked allocations" [ testCase "genKey" $ testMLockGenKeyKES _p -- , testCase "updateKey" $ testMLockUpdateKeyKES _p ] ] -testForgetGenKeyKES - :: forall v. - ( KESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) IO) v - ) - => Proxy v - -> Assertion -testForgetGenKeyKES _p = do - logVar <- newIORef [] - let tracer :: Tracer (LogT (GenericEvent ForgetMockEvent) IO) (GenericEvent ForgetMockEvent) - tracer = Tracer (\ev -> liftIO $ modifyIORef logVar (++ [ev])) - runLogT tracer $ do - seed <- MLockedSeed <$> mlsbFromByteString (BS.replicate 1024 23) - sk <- genKeyKES @(LogT (GenericEvent ForgetMockEvent) IO) @(ForgetMockKES v) seed - mlockedSeedFinalize seed - forgetSignKeyKES sk - result <- map concreteEvent <$> readIORef logVar - assertBool ("Unexpected log: " ++ show result) $ case result of - [GEN a, DEL b] -> - -- End of last period, so no update happened - a == b - _ -> False - return () - -testForgetUpdateKeyKES - :: forall v. - ( KESSignAlgorithm (LogT (GenericEvent ForgetMockEvent) IO) v - , ContextKES v ~ () - ) - => Proxy v - -> Assertion -testForgetUpdateKeyKES _p = do - logVar <- newIORef [] - let tracer :: Tracer (LogT (GenericEvent ForgetMockEvent) IO) (GenericEvent ForgetMockEvent) - tracer = Tracer (\ev -> liftIO $ modifyIORef logVar (++ [ev])) - runLogT tracer $ do - seed <- MLockedSeed <$> NaCl.mlsbFromByteString (BS.replicate 1024 23) - sk <- genKeyKES @(LogT (GenericEvent ForgetMockEvent) IO) @(ForgetMockKES v) seed - mlockedSeedFinalize seed - msk' <- updateKES () sk 0 - forgetSignKeyKES sk - traverse_ forgetSignKeyKES msk' - result <- map concreteEvent <$> readIORef logVar - - assertBool ("Unexpected log: " ++ show result) $ case result of - [GEN a, UPD b c, DEL d, DEL e] -> - -- Regular update - a == b && d == a && e == c - [GEN a, NOUPD, DEL b] -> - -- End of last period, so no update happened - a == b - _ -> False - +eventTracer :: IORef [event] -> Tracer IO event +eventTracer logVar = Tracer (\ev -> liftIO $ atomicModifyIORef' logVar (\acc -> (acc ++ [ev], ()))) matchAllocLog :: [AllocEvent] -> Set WordPtr matchAllocLog = foldl' (flip go) Set.empty @@ -225,54 +164,51 @@ matchAllocLog = foldl' (flip go) Set.empty testMLockGenKeyKES :: forall v. - ( KESSignAlgorithm (AllocLogT IO) v - ) + KESSignAlgorithm v => Proxy v -> Assertion testMLockGenKeyKES _p = do accumVar <- newIORef [] - let tracer = Tracer (\ev -> liftIO $ modifyIORef accumVar (++ [ev])) - runAllocLogT tracer $ do - pushAllocLogEvent $ MarkerEv "gen seed" - (seed :: MLockedSeed (SeedSizeKES v)) <- MLockedSeed <$> NaCl.mlsbFromByteString (BS.replicate 1024 23) - pushAllocLogEvent $ MarkerEv "gen key" - sk <- genKeyKES @_ @v seed - pushAllocLogEvent $ MarkerEv "forget key" - forgetSignKeyKES sk - pushAllocLogEvent $ MarkerEv "forget seed" - mlockedSeedFinalize seed - pushAllocLogEvent $ MarkerEv "done" + let tracer = eventTracer accumVar + let allocator = mkLoggingAllocator tracer mlockedMalloc + traceWith tracer $ MarkerEv "gen seed" + seed :: MLockedSeed (SeedSizeKES v) <- + MLockedSeed <$> mlsbFromByteStringWith allocator (BS.replicate 1024 23) + traceWith tracer $ MarkerEv "gen key" + sk <- genKeyKESWith @v allocator seed + traceWith tracer $ MarkerEv "forget key" + forgetSignKeyKESWith allocator sk + traceWith tracer $ MarkerEv "forget seed" + mlockedSeedFinalize seed + traceWith tracer $ MarkerEv "done" after <- readIORef accumVar let evset = matchAllocLog after + assertBool "some allocations happened" (not . null $ [ () | AllocEv _ <- after ]) assertEqual "all allocations deallocated" Set.empty evset {-# NOINLINE testKESAlgorithm#-} testKESAlgorithm - :: forall m v. + :: forall v. ( ToCBOR (VerKeyKES v) , FromCBOR (VerKeyKES v) - , MEq IO (SignKeyKES v) -- only monadic MEq for signing keys + , EqST (SignKeyKES v) -- only monadic EqST for signing keys , Show (SignKeyKES v) -- fake instance defined locally , ToCBOR (SigKES v) , FromCBOR (SigKES v) , Signable v ~ SignableRepresentation , ContextKES v ~ () - , KESSignAlgorithm m v - -- , KESSignAlgorithm IO v -- redundant for now - , UnsoundKESSignAlgorithm IO v + , UnsoundKESSignAlgorithm v ) => Lock - -> Proxy m - -> Proxy v -> String -> TestTree -testKESAlgorithm lock _pm _pv n = +testKESAlgorithm lock n = testGroup n - [ testProperty "only gen signkey" $ prop_onlyGenSignKeyKES @v lock Proxy - , testProperty "only gen verkey" $ prop_onlyGenVerKeyKES @v lock Proxy - , testProperty "one update signkey" $ prop_oneUpdateSignKeyKES lock (Proxy @IO) (Proxy @v) - , testProperty "all updates signkey" $ prop_allUpdatesSignKeyKES lock (Proxy @IO) (Proxy @v) - , testProperty "total periods" $ prop_totalPeriodsKES lock (Proxy @IO) (Proxy @v) + [ testProperty "only gen signkey" $ prop_onlyGenSignKeyKES @v lock + , testProperty "only gen verkey" $ prop_onlyGenVerKeyKES @v lock + , testProperty "one update signkey" $ prop_oneUpdateSignKeyKES @v lock + , testProperty "all updates signkey" $ prop_allUpdatesSignKeyKES @v lock + , testProperty "total periods" $ prop_totalPeriodsKES @v lock , testGroup "NoThunks" [ testProperty "VerKey" $ ioPropertyWithSK @v lock $ \sk -> @@ -287,11 +223,11 @@ testKESAlgorithm lock _pm _pv n = (maybe (return ()) forgetSignKeyKES) (prop_no_thunks_IO . return) , testProperty "Sig" $ \seedPSB (msg :: Message) -> - ioProperty $ withLock lock $ fmap conjoin $ withAllUpdatesKES @IO @v seedPSB $ \t sk -> do + ioProperty $ withLock lock $ fmap conjoin $ withAllUpdatesKES @v seedPSB $ \t sk -> do prop_no_thunks_IO (signKES () t msg sk) ] - , testProperty "same VerKey " $ prop_deriveVerKeyKES (Proxy @IO) (Proxy @v) + , testProperty "same VerKey " $ prop_deriveVerKeyKES @v , testGroup "serialisation" [ testGroup "raw ser only" @@ -380,16 +316,16 @@ testKESAlgorithm lock _pm _pv n = ] , testGroup "verify" - [ testProperty "positive" $ prop_verifyKES_positive @IO @v Proxy Proxy - , testProperty "negative (key)" $ prop_verifyKES_negative_key @IO @v Proxy Proxy - , testProperty "negative (message)" $ prop_verifyKES_negative_message @IO @v Proxy Proxy + [ testProperty "positive" $ prop_verifyKES_positive @v + , testProperty "negative (key)" $ prop_verifyKES_negative_key @v + , testProperty "negative (message)" $ prop_verifyKES_negative_message @v , adjustOption (\(QuickCheckMaxSize sz) -> QuickCheckMaxSize (min sz 50)) $ - testProperty "negative (period)" $ prop_verifyKES_negative_period @IO @v Proxy Proxy + testProperty "negative (period)" $ prop_verifyKES_negative_period @v ] , testGroup "serialisation of all KES evolutions" - [ testProperty "VerKey" $ prop_serialise_VerKeyKES @IO @v Proxy Proxy - , testProperty "Sig" $ prop_serialise_SigKES @IO @v Proxy Proxy + [ testProperty "VerKey" $ prop_serialise_VerKeyKES @v + , testProperty "Sig" $ prop_serialise_SigKES @v ] -- TODO: this doesn't pass right now, see @@ -406,11 +342,8 @@ testKESAlgorithm lock _pm _pv n = -- timely forgetting. Special care must be taken to not leak the key outside of -- the wrapped action (be particularly mindful of thunks and unsafe key access -- here). -withSK :: ( MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v - ) => PinnedSizedBytes (SeedSizeKES v) -> (SignKeyKES v -> m b) -> m b +withSK :: KESSignAlgorithm v + => PinnedSizedBytes (SeedSizeKES v) -> (SignKeyKES v -> IO b) -> IO b withSK seedPSB = bracket (withMLockedSeedFromPSB seedPSB genKeyKES) @@ -423,7 +356,7 @@ withSK seedPSB = -- memory. Special care must be taken to not leak the key outside of the -- wrapped action (be particularly mindful of thunks and unsafe key access -- here). -ioPropertyWithSK :: forall v a. (Testable a, KESSignAlgorithm IO v) +ioPropertyWithSK :: forall v a. (Testable a, KESSignAlgorithm v) => Lock -> (SignKeyKES v -> IO a) -> PinnedSizedBytes (SeedSizeKES v) @@ -456,70 +389,55 @@ ioPropertyWithSK lock action seedPSB = prop_onlyGenSignKeyKES :: forall v. - KESSignAlgorithm IO v - => Lock -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_onlyGenSignKeyKES lock _ = + KESSignAlgorithm v + => Lock -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_onlyGenSignKeyKES lock = ioPropertyWithSK @v lock $ const noExceptionsThrown prop_onlyGenVerKeyKES :: forall v. - KESSignAlgorithm IO v - => Lock -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_onlyGenVerKeyKES lock _ = + KESSignAlgorithm v + => Lock -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_onlyGenVerKeyKES lock = ioPropertyWithSK @v lock $ doesNotThrow . deriveVerKeyKES prop_oneUpdateSignKeyKES - :: forall m v. + :: forall v. ( ContextKES v ~ () - , RunIO m - , MonadFail m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Lock -> Proxy m -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_oneUpdateSignKeyKES lock _ _ seedPSB = - ioProperty . withLock lock . io . withMLockedSeedFromPSB seedPSB $ \seed -> do - sk <- genKeyKES @m @v seed - msk' <- updateKES @m () sk 0 + => Lock -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_oneUpdateSignKeyKES lock seedPSB = + ioProperty . withLock lock . withMLockedSeedFromPSB seedPSB $ \seed -> do + sk <- genKeyKES @v seed + msk' <- updateKES () sk 0 forgetSignKeyKES sk maybe (return ()) forgetSignKeyKES msk' return True prop_allUpdatesSignKeyKES - :: forall m v. + :: forall v. ( ContextKES v ~ () - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Lock -> Proxy m -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_allUpdatesSignKeyKES lock _ _ seedPSB = - ioProperty . withLock lock . io $ do - void $ withAllUpdatesKES_ @m @v seedPSB $ const (return ()) + => Lock -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_allUpdatesSignKeyKES lock seedPSB = + ioProperty . withLock lock $ do + void $ withAllUpdatesKES_ @v seedPSB $ const (return ()) -- | If we start with a signing key, we can evolve it a number of times so that -- the total number of signing keys (including the initial one) equals the -- total number of periods for this algorithm. -- prop_totalPeriodsKES - :: forall m v. + :: forall v. ( ContextKES v ~ () - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Lock -> Proxy m -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_totalPeriodsKES lock _ _ seed = + => Lock -> PinnedSizedBytes (SeedSizeKES v) -> Property +prop_totalPeriodsKES lock seed = ioProperty . withLock lock $ do - sks <- io $ withAllUpdatesKES_ @m @v seed (const . return $ ()) + sks <- withAllUpdatesKES_ @v seed (const . return $ ()) return $ totalPeriods > 0 ==> counterexample (show totalPeriods) $ @@ -534,25 +452,20 @@ prop_totalPeriodsKES lock _ _ seed = -- keys we derive from each one are the same. -- prop_deriveVerKeyKES - :: forall m v. + :: forall v. ( ContextKES v ~ () - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Property -prop_deriveVerKeyKES _ _ seedPSB = + => PinnedSizedBytes (SeedSizeKES v) -> Property +prop_deriveVerKeyKES seedPSB = ioProperty $ do - vk_0 <- io $ do - sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @m @v - vk_0 <- deriveVerKeyKES @m sk_0 + vk_0 <- do + sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @v + vk_0 <- deriveVerKeyKES sk_0 forgetSignKeyKES sk_0 return vk_0 - vks <- io $ withAllUpdatesKES_ seedPSB $ deriveVerKeyKES @m + vks <- withAllUpdatesKES_ seedPSB deriveVerKeyKES return $ counterexample (show vks) $ conjoin (map (vk_0 ===) vks) @@ -563,25 +476,20 @@ prop_deriveVerKeyKES _ _ seedPSB = -- corresponding period. -- prop_verifyKES_positive - :: forall m v. + :: forall v. ( ContextKES v ~ () , Signable v ~ SignableRepresentation - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v -> PinnedSizedBytes (SeedSizeKES v) -> Gen Property -prop_verifyKES_positive _ _ seedPSB = do + => PinnedSizedBytes (SeedSizeKES v) -> Gen Property +prop_verifyKES_positive seedPSB = do xs :: [Message] <- vectorOf totalPeriods arbitrary return $ checkCoverage $ cover 1 (length xs >= totalPeriods) "Message count covers total periods" $ not (null xs) ==> - ioProperty $ fmap conjoin $ io $ do - sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @m @v - vk <- deriveVerKeyKES @m sk_0 + ioProperty $ fmap conjoin $ do + sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @v + vk <- deriveVerKeyKES sk_0 forgetSignKeyKES sk_0 withAllUpdatesKES seedPSB $ \t sk -> do let x = cycle xs !! fromIntegral t @@ -600,24 +508,18 @@ prop_verifyKES_positive _ _ seedPSB = do -- corresponding to a different signing key, then the verification fails. -- prop_verifyKES_negative_key - :: forall m v. + :: forall v. ( ContextKES v ~ () , Signable v ~ SignableRepresentation - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v - -> PinnedSizedBytes (SeedSizeKES v) + => PinnedSizedBytes (SeedSizeKES v) -> PinnedSizedBytes (SeedSizeKES v) -> Message -> Property -prop_verifyKES_negative_key _ _ seedPSB seedPSB' x = - seedPSB /= seedPSB' ==> ioProperty $ fmap conjoin $ io $ do - sk_0' <- withMLockedSeedFromPSB seedPSB' $ genKeyKES @m @v +prop_verifyKES_negative_key seedPSB seedPSB' x = + seedPSB /= seedPSB' ==> ioProperty $ fmap conjoin $ do + sk_0' <- withMLockedSeedFromPSB seedPSB' $ genKeyKES @v vk' <- deriveVerKeyKES sk_0' forgetSignKeyKES sk_0' withAllUpdatesKES seedPSB $ \t sk -> do @@ -632,24 +534,18 @@ prop_verifyKES_negative_key _ _ seedPSB seedPSB' x = -- verification fails. -- prop_verifyKES_negative_message - :: forall m v. + :: forall v. ( ContextKES v ~ () , Signable v ~ SignableRepresentation - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v - -> PinnedSizedBytes (SeedSizeKES v) + => PinnedSizedBytes (SeedSizeKES v) -> Message -> Message -> Property -prop_verifyKES_negative_message _ _ seedPSB x x' = - x /= x' ==> ioProperty $ fmap conjoin $ io $ do - sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @m @v - vk <- deriveVerKeyKES @m sk_0 +prop_verifyKES_negative_message seedPSB x x' = + x /= x' ==> ioProperty $ fmap conjoin $ do + sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @v + vk <- deriveVerKeyKES sk_0 forgetSignKeyKES sk_0 withAllUpdatesKES seedPSB $ \t sk -> do sig <- signKES () t x sk @@ -664,24 +560,18 @@ prop_verifyKES_negative_message _ _ seedPSB x x' = -- verification fails. -- prop_verifyKES_negative_period - :: forall m v. + :: forall v. ( ContextKES v ~ () , Signable v ~ SignableRepresentation - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v - -> PinnedSizedBytes (SeedSizeKES v) + => PinnedSizedBytes (SeedSizeKES v) -> Message -> Property -prop_verifyKES_negative_period _ _ seedPSB x = - ioProperty $ fmap conjoin $ io $ do - sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @m @v - vk <- deriveVerKeyKES @m sk_0 +prop_verifyKES_negative_period seedPSB x = + ioProperty $ fmap conjoin $ do + sk_0 <- withMLockedSeedFromPSB seedPSB $ genKeyKES @v + vk <- deriveVerKeyKES sk_0 forgetSignKeyKES sk_0 withAllUpdatesKES seedPSB $ \t sk -> do sig <- signKES () t x sk @@ -700,22 +590,16 @@ prop_verifyKES_negative_period _ _ seedPSB x = -- for 'VerKeyKES' on /all/ the KES key evolutions. -- prop_serialise_VerKeyKES - :: forall m v. + :: forall v. ( ContextKES v ~ () - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v - -> PinnedSizedBytes (SeedSizeKES v) + => PinnedSizedBytes (SeedSizeKES v) -> Property -prop_serialise_VerKeyKES _ _ seedPSB = - ioProperty $ fmap conjoin $ io $ do - withAllUpdatesKES @m @v seedPSB $ \t sk -> do - vk <- deriveVerKeyKES @m sk +prop_serialise_VerKeyKES seedPSB = + ioProperty $ fmap conjoin $ do + withAllUpdatesKES @v seedPSB $ \t sk -> do + vk <- deriveVerKeyKES sk return $ counterexample ("period " ++ show t) $ counterexample ("vkey " ++ show vk) $ @@ -730,24 +614,18 @@ prop_serialise_VerKeyKES _ _ seedPSB = -- for 'SigKES' on /all/ the KES key evolutions. -- prop_serialise_SigKES - :: forall m v. + :: forall v. ( ContextKES v ~ () , Signable v ~ SignableRepresentation , Show (SignKeyKES v) - , RunIO m - , MonadIO m - , MonadSodium m - , MonadST m - , MonadThrow m - , KESSignAlgorithm m v + , KESSignAlgorithm v ) - => Proxy m -> Proxy v - -> PinnedSizedBytes (SeedSizeKES v) + => PinnedSizedBytes (SeedSizeKES v) -> Message -> Property -prop_serialise_SigKES _ _ seedPSB x = - ioProperty $ fmap conjoin $ io $ do - withAllUpdatesKES @m @v seedPSB $ \t sk -> do +prop_serialise_SigKES seedPSB x = + ioProperty $ fmap conjoin $ do + withAllUpdatesKES @v seedPSB $ \t sk -> do sig <- signKES () t x sk return $ counterexample ("period " ++ show t) $ @@ -764,34 +642,28 @@ prop_serialise_SigKES _ _ seedPSB x = -- KES test utils -- -withAllUpdatesKES_ :: forall m v a. - ( KESSignAlgorithm m v +withAllUpdatesKES_ :: forall v a. + ( KESSignAlgorithm v , ContextKES v ~ () - , MonadSodium m - , MonadST m - , MonadThrow m ) => PinnedSizedBytes (SeedSizeKES v) - -> (SignKeyKES v -> m a) - -> m [a] + -> (SignKeyKES v -> IO a) + -> IO [a] withAllUpdatesKES_ seedPSB f = do withAllUpdatesKES seedPSB (const f) -withAllUpdatesKES :: forall m v a. - ( KESSignAlgorithm m v +withAllUpdatesKES :: forall v a. + ( KESSignAlgorithm v , ContextKES v ~ () - , MonadSodium m - , MonadST m - , MonadThrow m ) => PinnedSizedBytes (SeedSizeKES v) - -> (Word -> SignKeyKES v -> m a) - -> m [a] + -> (Word -> SignKeyKES v -> IO a) + -> IO [a] withAllUpdatesKES seedPSB f = withMLockedSeedFromPSB seedPSB $ \seed -> do sk_0 <- genKeyKES seed go sk_0 0 where - go :: SignKeyKES v -> Word -> m [a] + go :: SignKeyKES v -> Word -> IO [a] go sk t = do x <- f t sk msk' <- updateKES () sk t diff --git a/cardano-crypto-tests/src/Test/Crypto/Util.hs b/cardano-crypto-tests/src/Test/Crypto/Util.hs index 1660f35dd..c0c7d7441 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Util.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Util.hs @@ -130,7 +130,7 @@ import qualified Test.QuickCheck.Gen as Gen import Control.Monad (guard, when) import GHC.TypeLits (Nat, KnownNat, natVal) import Formatting.Buildable (Buildable (..), build) -import Control.Concurrent.MVar (MVar, withMVar, newMVar) +import Control.Concurrent.Class.MonadMVar (MVar, withMVar, newMVar) import GHC.Stack (HasCallStack) -------------------------------------------------------------------------------- @@ -364,7 +364,7 @@ noExceptionsThrown = pure (property True) doesNotThrow :: Applicative m => m a -> m Property doesNotThrow = (*> noExceptionsThrown) -newtype Lock = Lock (MVar ()) +newtype Lock = Lock (MVar IO ()) withLock :: Lock -> IO a -> IO a withLock (Lock v) = withMVar v . const diff --git a/cardano-mempool/src/Cardano/Memory/Pool.hs b/cardano-mempool/src/Cardano/Memory/Pool.hs index d18aa9cfc..0be97d759 100644 --- a/cardano-mempool/src/Cardano/Memory/Pool.hs +++ b/cardano-mempool/src/Cardano/Memory/Pool.hs @@ -18,36 +18,39 @@ -- Currently there is no functionality for releasing unused pages. So, once a page is -- allocated, it will be re-used when more `Block`s is needed, but it will not be GCed -- until the whole `Pool` is GCed. -module Cardano.Memory.Pool - ( -- * Pool - Pool - , initPool +module Cardano.Memory.Pool ( + -- * Pool + Pool, + initPool, + -- * Block - , Block(..) - , blockByteCount - , grabNextBlock + Block (..), + blockByteCount, + grabNextBlock, + -- * Helpers + -- -- Exported for testing - , countPages - , findNextZeroIndex - ) where + countPages, + findNextZeroIndex, +) where -import Control.Monad -import Control.Monad.Primitive -import Foreign.Ptr -import Foreign.ForeignPtr -import GHC.ForeignPtr import Control.Applicative +import Control.Monad import Data.Bits -import GHC.TypeLits -import Data.Primitive.PrimArray +import Data.Primitive.MutVar import Data.Primitive.PVar import Data.Primitive.PVar.Unsafe (atomicModifyIntArray#) -import Data.IORef -import GHC.Int -import GHC.IO +import Data.Primitive.PrimArray +import Foreign.ForeignPtr +import Foreign.Ptr import GHC.Exts (fetchAndIntArray#) +import GHC.ForeignPtr (addForeignPtrConcFinalizer) +import GHC.IO +import GHC.Int +import GHC.ST +import GHC.TypeLits -- | This is just a proxy type that carries information at the type level about the size -- of the block in bytes supported by a particular instance of a `Pool`. Use @@ -61,78 +64,76 @@ blockByteCount = fromInteger . natVal -- | Internal helper type that manages each individual page. This is essentailly a mutable -- linked list, which contains a memory buffer, a bit array that tracks which blocks in -- the buffere are free and which ones are taken. -data Page n = - Page - { pageMemory :: !(ForeignPtr (Block n)) - -- ^ Contiguous memory buffer that holds all the blocks in the page. - , pageBitArray :: !(MutablePrimArray RealWorld Int) - -- ^ We use an Int array, because there are no built-in atomic primops for Word. - , pageFull :: !(PVar Int RealWorld) - -- ^ This is a boolean flag which indicates when a page is full. It here as - -- optimization only, because it allows us to skip iteration of the above bit - -- array. It is an `Int` instead of a `Bool`, because GHC provides atomic primops for - -- ByteArray, whcih is what `PVar` is based on. - , pageNextPage :: !(IORef (Maybe (Page n))) - -- ^ Link to the next page. Last page when this IORef contains `Nothing` - } +data Page n s = Page + { pageMemory :: !(ForeignPtr (Block n)) + -- ^ Contiguous memory buffer that holds all the blocks in the page. + , pageBitArray :: !(MutablePrimArray s Int) + -- ^ We use an Int array, because there are no built-in atomic primops for Word. + , pageFull :: !(PVar Int s) + -- ^ This is a boolean flag which indicates when a page is full. It here as + -- optimization only, because it allows us to skip iteration of the above bit + -- array. It is an `Int` instead of a `Bool`, because GHC provides atomic primops for + -- ByteArray, whcih is what `PVar` is based on. + , pageNextPage :: !(MutVar s (Maybe (Page n s))) + -- ^ Link to the next page. Last page when this IORef contains `Nothing` + } -- | Thread-safe lock-free memory pool for managing large memory pages that contain of -- many small `Block`s. -data Pool n = - Pool - { poolFirstPage :: !(Page n) - -- ^ Initial page, which itself contains references to subsequent pages - , poolPageInitializer :: !(IO (Page n)) - -- ^ Page initializing action - , poolBlockFinalizer :: !(Ptr (Block n) -> IO ()) - -- ^ Finilizer that will be attached to each individual `ForeignPtr` of a reserved - -- `Block`. - } +data Pool n s = Pool + { poolFirstPage :: !(Page n s) + -- ^ Initial page, which itself contains references to subsequent pages + , poolPageInitializer :: !(ST s (Page n s)) + -- ^ Page initializing action + , poolBlockFinalizer :: !(Ptr (Block n) -> IO ()) + -- ^ Finilizer that will be attached to each individual `ForeignPtr` of a reserved + -- `Block`. + } -- | Useful function for testing. Check how many pages have been allocated thus far. -countPages :: Pool n -> IO Int +countPages :: Pool n s -> ST s Int countPages pool = go 1 (poolFirstPage pool) where - go n Page {pageNextPage} = do - readIORef pageNextPage >>= \case + go n Page{pageNextPage} = do + readMutVar pageNextPage >>= \case Nothing -> pure n Just nextPage -> go (n + 1) nextPage - ixBitSize :: Int ixBitSize = finiteBitSize (0 :: Word) -- | Initilizes the `Pool` that can be used for further allocation of @`ForeignPtr` -- `Block` n@ with `grabNextBlock`. -initPool :: - forall n. KnownNat n +initPool + :: forall n s + . KnownNat n => Int -- ^ Number of groups per page. Must be a posititve number, otherwise error. One group -- contains as many blocks as the operating system has bits. A 64bit architecture will -- have 64 blocks per group. For example, if program is compiled on a 64 bit OS and you -- know ahead of time the maximum number of blocks that will be allocated through out -- the program, then the optimal value for this argument will @maxBlockNum/64@ - -> (forall a. Int -> IO (ForeignPtr a)) + -> (forall a. Int -> ST s (ForeignPtr a)) -- ^ Mempool page allocator. Some allocated pages might be immediately discarded, -- therefore number of pages utilized will not necessesarely match the number of times -- this action will be called. -> (Ptr (Block n) -> IO ()) -- ^ Finalizer to use for each block. It is an IO action because it will be executed by -- the Garbage Collector in a separate thread once the `Block` is no longer referenced. - -> IO (Pool n) + -> ST s (Pool n s) initPool groupsPerPage memAlloc blockFinalizer = do unless (groupsPerPage > 0) $ error $ - "Groups per page should be a positive number, but got: " ++ - show groupsPerPage + "Groups per page should be a positive number, but got: " + ++ show groupsPerPage let pageInit = do pageMemory <- memAlloc $ groupsPerPage * ixBitSize * blockByteCount (Block :: Block n) pageBitArray <- newPrimArray groupsPerPage setPrimArray pageBitArray 0 groupsPerPage 0 pageFull <- newPVar 0 - pageNextPage <- newIORef Nothing - pure Page {..} + pageNextPage <- newMutVar Nothing + pure Page{..} firstPage <- pageInit pure Pool @@ -145,45 +146,49 @@ initPool groupsPerPage memAlloc blockFinalizer = do -- finalizer attached to the `ForeignPtr` that will run `Block` pointer finalizer and -- release that memory for re-use by other blocks allocated in the future. It is safe to -- add more Haskell finalizers with `addForeignPtrConcFinalizer` if necessary. -grabNextBlock :: KnownNat n => Pool n -> IO (ForeignPtr (Block n)) +grabNextBlock :: KnownNat n => Pool n s -> ST s (ForeignPtr (Block n)) grabNextBlock = grabNextPoolBlockWith grabNextPageForeignPtr {-# INLINE grabNextBlock #-} -- | This is a helper function that will allocate a `Page` if the current `Page` in the -- `Pool` is full. Whenever there are still block slots are available then supplied -- @grabNext@ function will be used to reserve the slot in that `Page`. -grabNextPoolBlockWith :: - (Page n -> (Ptr (Block n) -> IO ()) -> IO (Maybe (ForeignPtr (Block n)))) - -> Pool n - -> IO (ForeignPtr (Block n)) +grabNextPoolBlockWith + :: (Page n s -> (Ptr (Block n) -> IO ()) -> ST s (Maybe (ForeignPtr (Block n)))) + -> Pool n s + -> ST s (ForeignPtr (Block n)) grabNextPoolBlockWith grabNext pool = go (poolFirstPage pool) where go page = do isPageFull <- atomicReadIntPVar (pageFull page) if intToBool isPageFull - then readIORef (pageNextPage page) >>= \case - Nothing -> do - newPage <- poolPageInitializer pool - -- There is a slight chance of a race condition in that the next page could - -- have been allocated and assigned to 'pageNextPage' by another thread - -- since we last checked for it. This is not a problem since we can safely - -- discard the page created in this thread and switch to the one that was - -- assigned to 'pageNextPage'. - mNextPage <- - atomicModifyIORef' (pageNextPage page) $ \mNextPage -> - (mNextPage <|> Just newPage, mNextPage) - case mNextPage of - Nothing -> go newPage - Just existingPage -> do - -- Here we cleanup the newly allocated page in favor of the one that - -- was potentially created by another thread. It is important to - -- eagerly free up scarce resources - finalizeForeignPtr (pageMemory newPage) - go existingPage - Just nextPage -> go nextPage - else grabNext page (poolBlockFinalizer pool) >>= \case - Nothing -> go page - Just ma -> pure ma + then + readMutVar (pageNextPage page) >>= \case + Nothing -> do + newPage <- poolPageInitializer pool + -- There is a slight chance of a race condition in that the next page could + -- have been allocated and assigned to 'pageNextPage' by another thread + -- since we last checked for it. This is not a problem since we can safely + -- discard the page created in this thread and switch to the one that was + -- assigned to 'pageNextPage'. + mNextPage <- + atomicModifyMutVar' (pageNextPage page) $ \mNextPage -> + (mNextPage <|> Just newPage, mNextPage) + case mNextPage of + Nothing -> go newPage + Just existingPage -> do + -- Here we cleanup the newly allocated page in favor of the one that + -- was potentially created by another thread. It is important to + -- eagerly free up scarce resources. + -- + -- This operation is idempotent and thread safe + unsafeIOToST $ finalizeForeignPtr (pageMemory newPage) + go existingPage + Just nextPage -> go nextPage + else + grabNext page (poolBlockFinalizer pool) >>= \case + Nothing -> go page + Just ma -> pure ma {-# INLINE grabNextPoolBlockWith #-} intToBool :: Int -> Bool @@ -193,14 +198,14 @@ intToBool _ = True -- | This is a helper function that will attempt to find the next available slot for the -- `Block` and create a `ForeignPtr` with the size of `Block` in the `Page`. In case when -- `Page` is full it will return `Nothing`. -grabNextPageForeignPtr :: - forall n. - KnownNat n - -- | Page to grab the block from - => Page n - -- | Finalizer to run, once the `ForeignPtr` holding on to `Ptr` `Block` is no longer used +grabNextPageForeignPtr + :: forall n s + . KnownNat n + => Page n s + -- ^ Page to grab the block from -> (Ptr (Block n) -> IO ()) - -> IO (Maybe (ForeignPtr (Block n))) + -- ^ Finalizer to run, once the `ForeignPtr` holding on to `Ptr` `Block` is no longer used + -> ST s (Maybe (ForeignPtr (Block n))) grabNextPageForeignPtr page finalizer = grabNextPageWithAllocator page $ \blockPtr resetIndex -> do fp <- newForeignPtr_ blockPtr @@ -208,12 +213,13 @@ grabNextPageForeignPtr page finalizer = pure fp {-# INLINE grabNextPageForeignPtr #-} -grabNextPageWithAllocator :: - forall n. KnownNat n - => Page n +grabNextPageWithAllocator + :: forall n s + . KnownNat n + => Page n s -> (Ptr (Block n) -> IO () -> IO (ForeignPtr (Block n))) - -> IO (Maybe (ForeignPtr (Block n))) -grabNextPageWithAllocator Page {..} allocator = do + -> ST s (Maybe (ForeignPtr (Block n))) +grabNextPageWithAllocator Page{..} allocator = do setNextZero pageBitArray >>= \case -- There is a slight chance that some Blocks will be cleared before the pageFull is -- set to True. This is not a problem because that memory will be recovered as soon as @@ -225,29 +231,30 @@ grabNextPageWithAllocator Page {..} allocator = do Nothing -> Nothing <$ atomicWriteIntPVar pageFull 1 Just ix -> fmap Just $ - withForeignPtr pageMemory $ \pagePtr -> - let !blockPtr = - plusPtr pagePtr $ ix * blockByteCount (Block :: Block n) - in allocator blockPtr $ do - let !(!q, !r) = ix `quotRem` ixBitSize - !pageBitMask = clearBit (complement 0) r - touch pageMemory - atomicAndIntMutablePrimArray pageBitArray q pageBitMask - atomicWriteIntPVar pageFull 0 + unsafeIOToST $ + withForeignPtr pageMemory $ \pagePtr -> + let !blockPtr = + plusPtr pagePtr $ ix * blockByteCount (Block :: Block n) + in allocator blockPtr $ do + let !(!q, !r) = ix `quotRem` ixBitSize + !pageBitMask = clearBit (complement 0) r + touchForeignPtr pageMemory + unsafeSTToIO $ atomicAndIntMutablePrimArray pageBitArray q pageBitMask + unsafeSTToIO $ atomicWriteIntPVar pageFull 0 {-# INLINE grabNextPageWithAllocator #-} -- | Atomically AND an element of the array -atomicAndIntMutablePrimArray :: MutablePrimArray RealWorld Int -> Int -> Int -> IO () +atomicAndIntMutablePrimArray :: MutablePrimArray s Int -> Int -> Int -> ST s () atomicAndIntMutablePrimArray (MutablePrimArray mba#) (I# i#) (I# m#) = - IO $ \s# -> + ST $ \s# -> case fetchAndIntArray# mba# i# m# s# of (# s'#, _ #) -> (# s'#, () #) {-# INLINE atomicAndIntMutablePrimArray #-} -- | Atomically modify an element of the array -atomicModifyMutablePrimArray :: MutablePrimArray RealWorld Int -> Int -> (Int -> (Int, a)) -> IO a +atomicModifyMutablePrimArray :: MutablePrimArray s Int -> Int -> (Int -> (Int, a)) -> ST s a atomicModifyMutablePrimArray (MutablePrimArray mba#) (I# i#) f = - IO $ atomicModifyIntArray# mba# i# (\x# -> case f (I# x#) of (I# y#, a) -> (# y#, a #)) + ST $ atomicModifyIntArray# mba# i# (\x# -> case f (I# x#) of (I# y#, a) -> (# y#, a #)) {-# INLINE atomicModifyMutablePrimArray #-} -- | Helper function that finds an index of the left-most bit that is not set. @@ -257,9 +264,10 @@ findNextZeroIndex b = i1 = countTrailingZeros (complement b) maxBits = finiteBitSize (undefined :: b) in if i0 == 0 - then if i1 == maxBits - then Nothing - else Just i1 + then + if i1 == maxBits + then Nothing + else Just i1 else Just (i0 - 1) {-# INLINE findNextZeroIndex #-} @@ -267,7 +275,7 @@ findNextZeroIndex b = -- atomically. In case when all bits are set, then `Nothing` is returned. It is possible -- that while search is ongoing bits that where checked get cleared. This is totally fine -- for our implementation of mempool. -setNextZero :: MutablePrimArray RealWorld Int -> IO (Maybe Int) +setNextZero :: MutablePrimArray s Int -> ST s (Maybe Int) setNextZero ma = ifindAtomicMutablePrimArray ma f where f i !w = @@ -276,18 +284,17 @@ setNextZero ma = ifindAtomicMutablePrimArray ma f Just !bitIx -> (setBit w bitIx, Just (ixBitSize * i + bitIx)) {-# INLINE setNextZero #-} - -ifindAtomicMutablePrimArray :: - MutablePrimArray RealWorld Int -> - (Int -> Int -> (Int, Maybe a)) -> - IO (Maybe a) +ifindAtomicMutablePrimArray + :: MutablePrimArray s Int + -> (Int -> Int -> (Int, Maybe a)) + -> ST s (Maybe a) ifindAtomicMutablePrimArray ma f = do n <- getSizeofMutablePrimArray ma let go i | i >= n = pure Nothing | otherwise = - atomicModifyMutablePrimArray ma i (f i) >>= \case - Nothing -> go (i + 1) - Just a -> pure $! Just a + atomicModifyMutablePrimArray ma i (f i) >>= \case + Nothing -> go (i + 1) + Just a -> pure $ Just a go 0 {-# INLINE ifindAtomicMutablePrimArray #-} diff --git a/cardano-mempool/tests/Test/Cardano/Memory/PoolTests.hs b/cardano-mempool/tests/Test/Cardano/Memory/PoolTests.hs index fe39a365f..d22a72a15 100644 --- a/cardano-mempool/tests/Test/Cardano/Memory/PoolTests.hs +++ b/cardano-mempool/tests/Test/Cardano/Memory/PoolTests.hs @@ -75,10 +75,10 @@ propFindNextZeroIndex w = monadicIO . run $ -- We allow one extra page be allocated due to concurrency false positives in block -- reservations -checkNumPages :: Pool n -> Int -> Int -> Assertion +checkNumPages :: Pool n RealWorld -> Int -> Int -> Assertion checkNumPages pool n numBlocks = do let estimatedUpperBoundOfPages = 1 + max 1 (numBlocks `div` n `div` 64) - numPages <- countPages pool + numPages <- stToPrim $ countPages pool assertBool (concat [ "Number of pages should not exceed the expected amount: " @@ -102,8 +102,8 @@ checkBlockBytes block byte ptr = checkFillByte (i - 1) in checkFillByte (blockByteCount block - 1) -mallocPreFilled :: Word8 -> Int -> IO (ForeignPtr b) -mallocPreFilled preFillByte bc = do +mallocPreFilled :: Word8 -> Int -> ST s (ForeignPtr b) +mallocPreFilled preFillByte bc = unsafeIOToPrim $ do mfp <- mallocForeignPtrBytes bc withForeignPtr mfp $ \ptr -> setPtr (castPtr ptr) bc preFillByte pure mfp @@ -166,11 +166,11 @@ propPoolGarbageCollected block (Positive n) numBlocks16 preFillByte fillByte = (pool, ptrs) <- ensureAllGCed numBlocks $ \countOneBlockGCed -> do pool <- - initPool n (mallocPreFilled preFillByte) $ \ptr -> do + stToPrim $ initPool n (mallocPreFilled preFillByte) $ \ptr -> do setPtr (castPtr ptr) (blockByteCount block) fillByte countOneBlockGCed fmps :: [ForeignPtr (Block n)] <- - replicateConcurrently numBlocks (grabNextBlock pool) + replicateConcurrently numBlocks (stToPrim $ grabNextBlock pool) touch fmps -- Here we return just the pointers and let the GC collect the ForeignPtrs ptrs <- @@ -201,14 +201,14 @@ propPoolAllocateAndFinalize block (Positive n) numBlocks16 emptyByte fullByte = ensureAllGCed numBlocks $ \countOneBlockGCed -> do chan <- newChan pool <- - initPool n (mallocPreFilled emptyByte) $ \(ptr :: Ptr (Block n)) -> do + stToPrim $ initPool n (mallocPreFilled emptyByte) $ \(ptr :: Ptr (Block n)) -> do setPtr (castPtr ptr) (blockByteCount block) emptyByte countOneBlockGCed -- allocate and finalize blocks concurrently pool <$ concurrently_ (do replicateConcurrently_ numBlocks $ do - fp <- grabNextBlock pool + fp <- stToPrim $ grabNextBlock pool withForeignPtr fp (checkBlockBytes block emptyByte) writeChan chan (Just fp) -- place Nothing to indicate that we are done allocating blocks From 0c6c14f7bb34d69695b28e2ec2111c36ebb0ac1e Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Mon, 22 May 2023 13:10:07 +0200 Subject: [PATCH 2/6] Move EqST to test suite --- .../cardano-crypto-class.cabal | 1 - .../src/Cardano/Crypto/DSIGN/Ed25519ML.hs | 6 ---- .../src/Cardano/Crypto/KES/Simple.hs | 8 ----- .../Crypto/Libsodium/MLockedBytes/Internal.hs | 4 --- .../Cardano/Crypto/Libsodium/MLockedSeed.hs | 8 ----- .../cardano-crypto-tests.cabal | 2 ++ cardano-crypto-tests/src/Test/Crypto/DSIGN.hs | 5 +-- .../src/Test}/Crypto/EqST.hs | 31 +++++++++++++++++-- cardano-crypto-tests/src/Test/Crypto/KES.hs | 2 +- 9 files changed, 35 insertions(+), 32 deletions(-) rename {cardano-crypto-class/src/Cardano => cardano-crypto-tests/src/Test}/Crypto/EqST.hs (69%) diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 4451c600e..79870ac34 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -79,7 +79,6 @@ library Cardano.Crypto.Libsodium.MLockedBytes.Internal Cardano.Crypto.Libsodium.MLockedSeed Cardano.Crypto.Libsodium.UnsafeC - Cardano.Crypto.EqST Cardano.Crypto.PinnedSizedBytes Cardano.Crypto.Seed Cardano.Crypto.Util diff --git a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs index a4cbc2e2f..097fa5e5e 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/DSIGN/Ed25519ML.hs @@ -56,9 +56,6 @@ import Cardano.Crypto.PinnedSizedBytes , psbFromByteStringCheck , psbCreateSizedResult ) -import Cardano.Crypto.EqST - ( EqST (..) - ) import Cardano.Crypto.DSIGNM.Class import Cardano.Crypto.Libsodium.MLockedSeed @@ -256,9 +253,6 @@ instance DSIGNMAlgorithm Ed25519DSIGNM where -- forgetSignKeyDSIGNMWith _ (SignKeyEd25519DSIGNM sk) = mlsbFinalize sk -deriving via (MLockedSizedBytes (SizeSignKeyDSIGNM Ed25519DSIGNM)) - instance EqST (SignKeyDSIGNM Ed25519DSIGNM) - instance UnsoundDSIGNMAlgorithm Ed25519DSIGNM where -- -- Ser/deser (dangerous - do not use in production code) diff --git a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs index c40bcffc1..8ec44fa14 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/KES/Simple.hs @@ -43,8 +43,6 @@ import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.Libsodium.MLockedBytes import Cardano.Crypto.Util import Data.Unit.Strict (forceElemsToWHNF) -import Cardano.Crypto.EqST (EqST (..)) - data SimpleKES d (t :: Nat) @@ -218,12 +216,6 @@ deriving instance DSIGNMAlgorithmBase d => Show (SigKES (SimpleKES d t)) deriving instance DSIGNMAlgorithmBase d => Eq (VerKeyKES (SimpleKES d t)) deriving instance DSIGNMAlgorithmBase d => Eq (SigKES (SimpleKES d t)) -instance EqST (SignKeyDSIGNM d) => EqST (SignKeyKES (SimpleKES d t)) where - equalsM (ThunkySignKeySimpleKES a) (ThunkySignKeySimpleKES b) = - -- No need to check that lengths agree, the types already guarantee this. - Vec.and <$> Vec.zipWithM equalsM a b - - instance DSIGNMAlgorithmBase d => NoThunks (SigKES (SimpleKES d t)) instance DSIGNMAlgorithmBase d => NoThunks (SignKeyKES (SimpleKES d t)) instance DSIGNMAlgorithmBase d => NoThunks (VerKeyKES (SimpleKES d t)) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs index 1d50ac25f..3c7ca290a 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedBytes/Internal.hs @@ -55,7 +55,6 @@ import Cardano.Foreign import Cardano.Crypto.Libsodium.Memory import Cardano.Crypto.Libsodium.Memory.Internal (MLockedForeignPtr (..)) import Cardano.Crypto.Libsodium.C -import Cardano.Crypto.EqST import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as BSI @@ -84,9 +83,6 @@ instance KnownNat n => Show (MLockedSizedBytes n) where -- hexstr = concatMap (printf "%02x") bytes -- in "MLSB " ++ hexstr -instance KnownNat n => EqST (MLockedSizedBytes n) where - equalsM = mlsbEq - nextPowerOf2 :: forall n. (Num n, Ord n, Bits n) => n -> n nextPowerOf2 i = go 1 diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs index 0677a9a13..5fb8c600d 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/MLockedSeed.hs @@ -20,9 +20,6 @@ import Cardano.Crypto.Libsodium.Memory ( MLockedAllocator, mlockedMalloc, ) -import Cardano.Crypto.EqST ( - EqST (..), - ) import Cardano.Foreign (SizedPtr) import Control.DeepSeq (NFData) import Control.Monad.Class.MonadST (MonadST) @@ -37,11 +34,6 @@ import NoThunks.Class (NoThunks) newtype MLockedSeed n = MLockedSeed {mlockedSeedMLSB :: MLockedSizedBytes n} deriving (NFData, NoThunks) -deriving via - MLockedSizedBytes n - instance - KnownNat n => EqST (MLockedSeed n) - withMLockedSeedAsMLSB :: Functor m => (MLockedSizedBytes n -> m (MLockedSizedBytes n)) diff --git a/cardano-crypto-tests/cardano-crypto-tests.cabal b/cardano-crypto-tests/cardano-crypto-tests.cabal index 040043a14..77c9596e1 100644 --- a/cardano-crypto-tests/cardano-crypto-tests.cabal +++ b/cardano-crypto-tests/cardano-crypto-tests.cabal @@ -54,6 +54,7 @@ library Test.Crypto.VRF Test.Crypto.Regressions Test.Crypto.Instances + Test.Crypto.EqST Bench.Crypto.DSIGN Bench.Crypto.VRF Bench.Crypto.KES @@ -86,6 +87,7 @@ library , criterion , base16-bytestring , tasty-hunit + , vector if flag(secp256k1-support) cpp-options: -DSECP256K1_ENABLED diff --git a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs index fcb969906..85de68852 100644 --- a/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs +++ b/cardano-crypto-tests/src/Test/Crypto/DSIGN.hs @@ -33,7 +33,6 @@ import Test.Tasty.QuickCheck (testProperty, QuickCheckTests) import qualified Data.ByteString as BS import Cardano.Crypto.Libsodium -import Cardano.Crypto.EqST (EqST (..), (==!)) import Text.Show.Pretty (ppShow) @@ -133,9 +132,11 @@ import Test.Crypto.Util ( Lock, withLock, ) -import Test.Crypto.Instances (withMLockedSeedFromPSB) import Cardano.Crypto.Libsodium.MLockedSeed +import Test.Crypto.Instances (withMLockedSeedFromPSB) +import Test.Crypto.EqST (EqST (..), (==!)) + #ifdef SECP256K1_ENABLED import Cardano.Crypto.DSIGN ( EcdsaSecp256k1DSIGN, diff --git a/cardano-crypto-class/src/Cardano/Crypto/EqST.hs b/cardano-crypto-tests/src/Test/Crypto/EqST.hs similarity index 69% rename from cardano-crypto-class/src/Cardano/Crypto/EqST.hs rename to cardano-crypto-tests/src/Test/Crypto/EqST.hs index ffdae7e0b..44fabcfdf 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/EqST.hs +++ b/cardano-crypto-tests/src/Test/Crypto/EqST.hs @@ -1,10 +1,21 @@ -{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -module Cardano.Crypto.EqST where +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE StandaloneDeriving #-} +module Test.Crypto.EqST where + +import GHC.TypeLits (KnownNat) +import qualified Data.Vector as Vec import Control.Monad.Class.MonadST (MonadST) +import Cardano.Crypto.Libsodium.MLockedBytes.Internal +import Cardano.Crypto.Libsodium.MLockedSeed +import Cardano.Crypto.DSIGN.Ed25519ML +import Cardano.Crypto.DSIGNM.Class +import Cardano.Crypto.KES.Simple + -- | Monadic flavor of 'Eq', for things that can only be compared in a monadic -- context that satisfies 'MonadST'. -- This is needed because we cannot have a sound 'Eq' instance on mlocked @@ -56,3 +67,19 @@ newtype PureEqST a = PureEqST a instance Eq a => EqST (PureEqST a) where equalsM (PureEqST a) (PureEqST b) = pure (a == b) + +instance KnownNat n => EqST (MLockedSizedBytes n) where + equalsM = mlsbEq + +deriving via + MLockedSizedBytes n + instance + KnownNat n => EqST (MLockedSeed n) + +deriving via (MLockedSizedBytes (SizeSignKeyDSIGNM Ed25519DSIGNM)) + instance EqST (SignKeyDSIGNM Ed25519DSIGNM) + +instance EqST (SignKeyDSIGNM d) => EqST (SignKeyKES (SimpleKES d t)) where + equalsM (ThunkySignKeySimpleKES a) (ThunkySignKeySimpleKES b) = + -- No need to check that lengths agree, the types already guarantee this. + Vec.and <$> Vec.zipWithM equalsM a b diff --git a/cardano-crypto-tests/src/Test/Crypto/KES.hs b/cardano-crypto-tests/src/Test/Crypto/KES.hs index 4212f3188..3b9b80780 100644 --- a/cardano-crypto-tests/src/Test/Crypto/KES.hs +++ b/cardano-crypto-tests/src/Test/Crypto/KES.hs @@ -45,7 +45,6 @@ import Cardano.Crypto.Util (SignableRepresentation(..)) import Cardano.Crypto.Libsodium import Cardano.Crypto.Libsodium.MLockedSeed import Cardano.Crypto.PinnedSizedBytes -import Cardano.Crypto.EqST import Test.QuickCheck import Test.Tasty (TestTree, testGroup, adjustOption) @@ -69,6 +68,7 @@ import Test.Crypto.Util ( Lock, withLock, ) +import Test.Crypto.EqST import Test.Crypto.Instances (withMLockedSeedFromPSB) import Test.Crypto.AllocLog From aaf1b61599dbe996d7eb84da1b6ef3c808861e85 Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Mon, 22 May 2023 15:39:03 +0200 Subject: [PATCH 3/6] Remove dead code --- .../src/Cardano/Crypto/Libsodium/Memory.hs | 2 -- .../src/Cardano/Crypto/Libsodium/Memory/Internal.hs | 11 ----------- 2 files changed, 13 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 3c04e37e5..9806f0ad8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -8,8 +8,6 @@ module Cardano.Crypto.Libsodium.Memory ( -- * MLocked allocations mlockedMalloc, MLockedAllocator (..), - AllocatorEvent(..), - getAllocatorEvent, mlockedAlloca, mlockedAllocaSized, diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 79e4162dd..37c2a9cda 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -16,8 +16,6 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( -- * MLocked allocations mlockedMalloc, MLockedAllocator (..), - AllocatorEvent(..), - getAllocatorEvent, mlockedAlloca, mlockedAllocaSized, @@ -207,15 +205,6 @@ packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString packByteStringCStringLen (ptr, len) = withLiftST $ \lift -> lift . unsafeIOToST $ BS.packCStringLen (ptr, len) -data AllocatorEvent where - AllocatorEvent :: (Show e, Typeable e) => e -> AllocatorEvent - -instance Show AllocatorEvent where - show (AllocatorEvent e) = "(AllocatorEvent " ++ show e ++ ")" - -getAllocatorEvent :: forall e. Typeable e => AllocatorEvent -> Maybe e -getAllocatorEvent (AllocatorEvent e) = cast e - newtype MLockedAllocator m = MLockedAllocator { mlAllocate :: forall a. CSize -> m (MLockedForeignPtr a) From 3ec7bf55f68dc95a532bc4213fd4a8eda62ab62f Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Wed, 24 May 2023 11:09:08 +0200 Subject: [PATCH 4/6] Address review comments --- .../src/Cardano/Crypto/Libsodium/Memory.hs | 2 +- .../Crypto/Libsodium/Memory/Internal.hs | 21 ++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 9806f0ad8..4d681b11a 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -26,7 +26,7 @@ module Cardano.Crypto.Libsodium.Memory ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' - useByteStringAsCStringLen, + unpackByteStringCStringLen, packByteStringCStringLen, ) where diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 37c2a9cda..729fcd24b 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -34,7 +34,7 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' - useByteStringAsCStringLen, + unpackByteStringCStringLen, packByteStringCStringLen, -- * Helper @@ -53,6 +53,7 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable +import Data.Word (Word8) import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) @@ -63,7 +64,7 @@ import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) -import Foreign.Storable (Storable (peek), sizeOf, alignment) +import Foreign.Storable (Storable (peek), sizeOf, alignment, pokeByteOff) import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -194,11 +195,15 @@ allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b allocaBytes size f = unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) -useByteStringAsCStringLen :: ByteString -> (CStringLen -> ST s a) -> ST s a -useByteStringAsCStringLen bs f = - allocaBytes (BS.length bs + 1) $ \buf -> do - len <- unsafeIOToST $ BS.unsafeUseAsCStringLen bs $ \(ptr, len) -> - len <$ copyMem buf ptr (fromIntegral len) +-- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' +-- function on it. +unpackByteStringCStringLen :: ByteString -> (CStringLen -> ST s a) -> ST s a +unpackByteStringCStringLen bs f = do + let len = BS.length bs + allocaBytes (len + 1) $ \buf -> do + unsafeIOToST $ BS.unsafeUseAsCString bs $ \ptr -> do + copyMem buf ptr (fromIntegral len) + pokeByteOff buf len (0 :: Word8) f (buf, len) packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString @@ -226,6 +231,8 @@ mlockedAllocForeignPtrBytes :: MonadST m => CSize -> CSize -> m (MLockedForeignP mlockedAllocForeignPtrBytes = mlockedAllocForeignPtrBytesWith mlockedMalloc mlockedAllocForeignPtrBytesWith :: MLockedAllocator m -> CSize -> CSize -> m (MLockedForeignPtr a) +mlockedAllocForeignPtrBytesWith _ _ 0 = + error "Zero alignment" mlockedAllocForeignPtrBytesWith allocator size align = do mlAllocate allocator size' where From 061659d1d4470cd4d80e4234e3e62801d5cbce9c Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Wed, 21 Jun 2023 09:31:37 +0200 Subject: [PATCH 5/6] Remove unpackByteStringCStringLen --- .../src/Cardano/Crypto/Libsodium/Memory.hs | 1 - .../Cardano/Crypto/Libsodium/Memory/Internal.hs | 16 +--------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs index 4d681b11a..a4405ef5d 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory.hs @@ -26,7 +26,6 @@ module Cardano.Crypto.Libsodium.Memory ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' - unpackByteStringCStringLen, packByteStringCStringLen, ) where diff --git a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs index 729fcd24b..b32854db4 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Libsodium/Memory/Internal.hs @@ -34,7 +34,6 @@ module Cardano.Crypto.Libsodium.Memory.Internal ( allocaBytes, -- * ByteString memory access, generalized to 'MonadST' - unpackByteStringCStringLen, packByteStringCStringLen, -- * Helper @@ -50,10 +49,8 @@ import Control.Monad.ST import Control.Monad.ST.Unsafe (unsafeIOToST, unsafeSTToIO) import Data.ByteString (ByteString) import qualified Data.ByteString as BS -import qualified Data.ByteString.Unsafe as BS import Data.Coerce (coerce) import Data.Typeable -import Data.Word (Word8) import Debug.Trace (traceShowM) import Foreign.C.Error (errnoToIOError, getErrno) import Foreign.C.String (CStringLen) @@ -64,7 +61,7 @@ import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr) import qualified Foreign.Marshal.Alloc as Foreign import Foreign.Marshal.Utils (fillBytes) import Foreign.Ptr (Ptr, nullPtr, castPtr) -import Foreign.Storable (Storable (peek), sizeOf, alignment, pokeByteOff) +import Foreign.Storable (Storable (peek), sizeOf, alignment) import GHC.IO.Exception (ioException) import GHC.TypeLits (KnownNat, natVal) import NoThunks.Class (NoThunks, OnlyCheckWhnfNamed (..)) @@ -195,17 +192,6 @@ allocaBytes :: Int -> (Ptr a -> ST s b) -> ST s b allocaBytes size f = unsafeIOToST $ Foreign.allocaBytes size (unsafeSTToIO . f) --- | Unpacks a ByteString into a temporary buffer and runs the provided 'ST' --- function on it. -unpackByteStringCStringLen :: ByteString -> (CStringLen -> ST s a) -> ST s a -unpackByteStringCStringLen bs f = do - let len = BS.length bs - allocaBytes (len + 1) $ \buf -> do - unsafeIOToST $ BS.unsafeUseAsCString bs $ \ptr -> do - copyMem buf ptr (fromIntegral len) - pokeByteOff buf len (0 :: Word8) - f (buf, len) - packByteStringCStringLen :: MonadST m => CStringLen -> m ByteString packByteStringCStringLen (ptr, len) = withLiftST $ \lift -> lift . unsafeIOToST $ BS.packCStringLen (ptr, len) From 4ee7879de865f9bd32280382d63203bfda5ffc1b Mon Sep 17 00:00:00 2001 From: Tobias Dammers Date: Wed, 21 Jun 2023 09:42:47 +0200 Subject: [PATCH 6/6] Add new functionality to changelog --- cardano-crypto-class/CHANGELOG.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cardano-crypto-class/CHANGELOG.md b/cardano-crypto-class/CHANGELOG.md index 0a8620dfd..a61b593b0 100644 --- a/cardano-crypto-class/CHANGELOG.md +++ b/cardano-crypto-class/CHANGELOG.md @@ -7,12 +7,19 @@ solidified. Ask @lehins if backport is needed. * Introduce memory locking and secure forgetting functionality: [#255](https://github.com/input-output-hk/cardano-base/pull/255) + [#404](https://github.com/input-output-hk/cardano-base/pull/404) * KES started using the new memlocking functionality: [#255](https://github.com/input-output-hk/cardano-base/pull/255) + [#404](https://github.com/input-output-hk/cardano-base/pull/404) * Introduction of `DSIGNM` that uses the new memlocking functionality: - [#255](https://github.com/input-output-hk/cardano-base/pull/255) + [#404](https://github.com/input-output-hk/cardano-base/pull/404) * Included bindings to `blst` library to enable operations over curve BLS12-381 [#266](https://github.com/input-output-hk/cardano-base/pull/266) +* Introduction of `DirectSerialise` / `DirectDeserialise` APIs, providing + direct access to mlocked keys in RAM: + [#404](https://github.com/input-output-hk/cardano-base/pull/404) +* Restructuring of libsodium bindings and related APIs: + [#404](https://github.com/input-output-hk/cardano-base/pull/404) ## 2.1.0.2