From fb9b71f3bc33f8de673c6427736f09bf7972e81f Mon Sep 17 00:00:00 2001 From: Alexey Kuleshevich Date: Mon, 16 Sep 2024 23:45:58 -0600 Subject: [PATCH] WIP: use MemPack --- cabal.project | 7 + .../cardano-crypto-class.cabal | 2 + .../src/Cardano/Crypto/Hash/Class.hs | 5 + .../src/Cardano/Crypto/PackedBytes.hs | 134 +++++++++++------- .../cardano-crypto-tests.cabal | 1 + cardano-crypto-tests/src/Test/Crypto/Hash.hs | 7 + 6 files changed, 105 insertions(+), 51 deletions(-) diff --git a/cabal.project b/cabal.project index e59ab314b..1a59b4f9f 100644 --- a/cabal.project +++ b/cabal.project @@ -57,3 +57,10 @@ if impl(ghc >= 9.8) location: https://github.com/recursion-schemes/recursion-schemes tag: cc2e88c3400a6548e975830c9addb12ab087545f --sha256: 06shyihy6cpblv3pf18xgdfjgxqw2y2awvpcy33r76fr642gdvgn + + +source-repository-package + type: git + location: https://github.com/lehins/mempack.git + tag: f07b53fbfc3c56d4d60e072e277ffdf655aee59e + --sha256: sha256-tkgPmpFQ2h5hX8gh3tQ5T5H756tmqgUNGb2hLQUgLWc= diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 2a531846b..215462a5e 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -100,6 +100,8 @@ library , deepseq , heapwords , memory + , mempack + , mtl , nothunks , primitive , serialise diff --git a/cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs b/cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs index 3351a59e2..c1c8e8a90 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/Hash/Class.hs @@ -5,8 +5,10 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiWayIf #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -66,6 +68,7 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Base16 as Base16 import qualified Data.ByteString.Char8 as BSC import Data.ByteString.Short (ShortByteString) +import Data.MemPack import Data.Word (Word8) import Numeric.Natural (Natural) @@ -110,6 +113,8 @@ sizeHash _ = fromInteger (natVal (Proxy @(SizeHash h))) newtype Hash h a = UnsafeHashRep (PackedBytes (SizeHash h)) deriving (Eq, Ord, Generic, NoThunks, NFData) +deriving instance HashAlgorithm h => MemPack (Hash h a) + -- | This instance is meant to be used with @TemplateHaskell@ -- -- >>> import Cardano.Crypto.Hash.Class (Hash) diff --git a/cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs b/cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs index 07d200ae3..a1436ef7e 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/PackedBytes.hs @@ -27,11 +27,15 @@ import Codec.Serialise.Encoding (encodeBytes) import Control.DeepSeq import Control.Monad (guard) import Control.Monad.Primitive +import Control.Monad.Reader +import Control.Monad.State.Strict import Data.Bits import Data.ByteString import Data.ByteString.Internal as BS (accursedUnutterablePerformIO, fromForeignPtr, toForeignPtr) import Data.ByteString.Short.Internal as SBS +import Data.MemPack +import Data.MemPack.Buffer import Data.Primitive.ByteArray import Data.Primitive.PrimArray (PrimArray(..), imapPrimArray, indexPrimArray) import Data.Typeable @@ -39,7 +43,6 @@ import Foreign.ForeignPtr import Foreign.Ptr (castPtr) import Foreign.Storable (peekByteOff) import GHC.Exts -import GHC.ForeignPtr (ForeignPtr(ForeignPtr), ForeignPtrContents(PlainPtr)) #if MIN_VERSION_base(4,15,0) import GHC.ForeignPtr (unsafeWithForeignPtr) #endif @@ -92,7 +95,38 @@ instance NFData (PackedBytes n) where rnf PackedBytes32 {} = () rnf PackedBytes# {} = () -instance Serialise (PackedBytes n) where +instance KnownNat n => MemPack (PackedBytes n) where + packedByteCount = fromIntegral @Integer @Int . natVal + {-# INLINE packedByteCount #-} + packM pb = do + let !len@(I# len#) = packedByteCount pb + i@(I# i#) <- state $ \i -> (i, i + len) + mba@(MutableByteArray mba#) <- ask + Pack $ \_ -> lift $ case pb of + PackedBytes8 w -> writeWord64BE mba i w + PackedBytes28 w0 w1 w2 w3 -> do + writeWord64BE mba i w0 + writeWord64BE mba (i + 8) w1 + writeWord64BE mba (i + 16) w2 + writeWord32BE mba (i + 24) w3 + PackedBytes32 w0 w1 w2 w3 -> do + writeWord64BE mba i w0 + writeWord64BE mba (i + 8) w1 + writeWord64BE mba (i + 16) w2 + writeWord64BE mba (i + 24) w3 + PackedBytes# ba# -> + st_ (copyByteArray# ba# 0# mba# i# len#) + {-# INLINE packM #-} + unpackM = do + let !len = fromIntegral @Integer @Int $ natVal' (proxy# :: Proxy# n) + curPos@(I# curPos#) <- guardAdvanceUnpack len + buf <- ask + pure $! buffer buf + (\ba# -> packBytes (SBS.SBS ba#) curPos) + (\addr# -> accursedUnutterablePerformIO $ packPinnedPtr (Ptr (addr# `plusAddr#` curPos#))) + {-# INLINE unpackM #-} + +instance KnownNat n => Serialise (PackedBytes n) where encode = encodeBytes . unpackPinnedBytes decode = packPinnedBytesN <$> decodeBytes @@ -221,53 +255,60 @@ packBytesMaybe bs offset = do Just $ packBytes bs offset -packPinnedBytes8 :: ByteString -> PackedBytes 8 -packPinnedBytes8 bs = unsafeWithByteStringPtr bs (fmap PackedBytes8 . (`peekWord64BE` 0)) -{-# INLINE packPinnedBytes8 #-} - -packPinnedBytes28 :: ByteString -> PackedBytes 28 -packPinnedBytes28 bs = - unsafeWithByteStringPtr bs $ \ptr -> - PackedBytes28 - <$> peekWord64BE ptr 0 - <*> peekWord64BE ptr 8 - <*> peekWord64BE ptr 16 - <*> peekWord32BE ptr 24 -{-# INLINE packPinnedBytes28 #-} - -packPinnedBytes32 :: ByteString -> PackedBytes 32 -packPinnedBytes32 bs = - unsafeWithByteStringPtr bs $ \ptr -> PackedBytes32 <$> peekWord64BE ptr 0 - <*> peekWord64BE ptr 8 - <*> peekWord64BE ptr 16 - <*> peekWord64BE ptr 24 -{-# INLINE packPinnedBytes32 #-} - -packPinnedBytesN :: ByteString -> PackedBytes n -packPinnedBytesN bs = - case toShort bs of - SBS ba# -> PackedBytes# ba# -{-# INLINE packPinnedBytesN #-} +packPinnedPtr8 :: Ptr a -> IO (PackedBytes 8) +packPinnedPtr8 = fmap PackedBytes8 . (`peekWord64BE` 0) +{-# INLINE packPinnedPtr8 #-} + +packPinnedPtr28 :: Ptr a -> IO (PackedBytes 28) +packPinnedPtr28 ptr = + PackedBytes28 + <$> peekWord64BE ptr 0 + <*> peekWord64BE ptr 8 + <*> peekWord64BE ptr 16 + <*> peekWord32BE ptr 24 +{-# INLINE packPinnedPtr28 #-} + +packPinnedPtr32 :: Ptr a -> IO (PackedBytes 32) +packPinnedPtr32 ptr = + PackedBytes32 <$> peekWord64BE ptr 0 + <*> peekWord64BE ptr 8 + <*> peekWord64BE ptr 16 + <*> peekWord64BE ptr 24 +{-# INLINE packPinnedPtr32 #-} + +packPinnedPtrN :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n) +packPinnedPtrN (Ptr addr#) = pure $! PackedBytes# ba# + where + !(ByteArray ba#) = withMutableByteArray len $ \(MutableByteArray mba#) -> + st_ (copyAddrToByteArray# addr# mba# 0# len#) + !len@(I# len#) = fromIntegral @Integer @Int (natVal' (proxy# :: Proxy# n)) +{-# INLINE packPinnedPtrN #-} +packPinnedBytesN :: KnownNat n => ByteString -> PackedBytes n +packPinnedBytesN bs = unsafeWithByteStringPtr bs packPinnedPtrN +{-# INLINE packPinnedBytesN #-} -packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n -packPinnedBytes bs = +packPinnedPtr :: forall n a. KnownNat n => Ptr a -> IO (PackedBytes n) +packPinnedPtr bs = let px = Proxy :: Proxy n in case sameNat px (Proxy :: Proxy 8) of - Just Refl -> packPinnedBytes8 bs + Just Refl -> packPinnedPtr8 bs Nothing -> case sameNat px (Proxy :: Proxy 28) of - Just Refl -> packPinnedBytes28 bs + Just Refl -> packPinnedPtr28 bs Nothing -> case sameNat px (Proxy :: Proxy 32) of - Just Refl -> packPinnedBytes32 bs - Nothing -> packPinnedBytesN bs -{-# INLINE[1] packPinnedBytes #-} - + Just Refl -> packPinnedPtr32 bs + Nothing -> packPinnedPtrN bs +{-# INLINE[1] packPinnedPtr #-} {-# RULES -"packPinnedBytes8" packPinnedBytes = packPinnedBytes8 -"packPinnedBytes28" packPinnedBytes = packPinnedBytes28 -"packPinnedBytes32" packPinnedBytes = packPinnedBytes32 +"packPinnedPtr8" packPinnedPtr = packPinnedPtr8 +"packPinnedPtr28" packPinnedPtr = packPinnedPtr28 +"packPinnedPtr32" packPinnedPtr = packPinnedPtr32 #-} +packPinnedBytes :: forall n . KnownNat n => ByteString -> PackedBytes n +packPinnedBytes bs = unsafeWithByteStringPtr bs packPinnedPtr +{-# INLINE packPinnedBytes #-} + --- Primitive architecture agnostic helpers @@ -358,22 +399,13 @@ writeWord32BE (MutableByteArray mba#) (I# i#) w = #endif {-# INLINE writeWord32BE #-} -byteArrayToShortByteString :: ByteArray -> ShortByteString -byteArrayToShortByteString (ByteArray ba#) = SBS ba# -{-# INLINE byteArrayToShortByteString #-} - byteArrayToByteString :: ByteArray -> ByteString -byteArrayToByteString ba +byteArrayToByteString ba@(ByteArray ba#) | isByteArrayPinned ba = - BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba) 0 (sizeofByteArray ba) + BS.fromForeignPtr (pinnedByteArrayToForeignPtr ba#) 0 (sizeofByteArray ba) | otherwise = SBS.fromShort (byteArrayToShortByteString ba) {-# INLINE byteArrayToByteString #-} -pinnedByteArrayToForeignPtr :: ByteArray -> ForeignPtr a -pinnedByteArrayToForeignPtr (ByteArray ba#) = - ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#)) -{-# INLINE pinnedByteArrayToForeignPtr #-} - -- Usage of `accursedUnutterablePerformIO` here is safe because we only use it -- for indexing into an immutable `ByteString`, which is analogous to -- `Data.ByteString.index`. Make sure you know what you are doing before using diff --git a/cardano-crypto-tests/cardano-crypto-tests.cabal b/cardano-crypto-tests/cardano-crypto-tests.cabal index 293a0a41a..40f1905b0 100644 --- a/cardano-crypto-tests/cardano-crypto-tests.cabal +++ b/cardano-crypto-tests/cardano-crypto-tests.cabal @@ -72,6 +72,7 @@ library , cryptonite , deepseq , formatting + , mempack , nothunks , pretty-show , QuickCheck diff --git a/cardano-crypto-tests/src/Test/Crypto/Hash.hs b/cardano-crypto-tests/src/Test/Crypto/Hash.hs index 052a17762..c68a99b36 100644 --- a/cardano-crypto-tests/src/Test/Crypto/Hash.hs +++ b/cardano-crypto-tests/src/Test/Crypto/Hash.hs @@ -16,6 +16,7 @@ import qualified Data.Bits as Bits (xor) import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import Data.Maybe (fromJust) +import Data.MemPack import Data.Proxy (Proxy(..)) import Data.String (fromString) import GHC.TypeLits @@ -62,9 +63,15 @@ testHashAlgorithm p = , testProperty "hashFromStringAsHex/fromString" $ prop_hash_hashFromStringAsHex_fromString @h @Float , testProperty "show/read" $ prop_hash_show_read @h @Float , testProperty "NoThunks" $ prop_no_thunks @(Hash h Int) + , testProperty "MemPack RoundTrip" $ prop_MemPackRoundTrip @(Hash h Int) ] where n = hashAlgorithmName p +prop_MemPackRoundTrip :: forall a. (MemPack a, Eq a, Show a) => a -> Property +prop_MemPackRoundTrip a = + unpackError (pack a) === a .&&. + unpackError (packByteString a) === a + testSodiumHashAlgorithm :: forall proxy h. NaCl.SodiumHashAlgorithm h => proxy h