Skip to content

Commit

Permalink
Merge pull request #109 from input-output-hk/jdral/factor-out-checked…
Browse files Browse the repository at this point in the history
…-strict-tvar

Factour out checked `StrictTVar`s
  • Loading branch information
jorisdral committed Jul 26, 2023
2 parents c1d12b6 + 52ea816 commit 8293db5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 103 deletions.
6 changes: 6 additions & 0 deletions strict-stm/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## next version

### Breaking changes

* Remove invariants for `StrictTVar`s.

## 1.1.0.1

### Non-breaking changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ import Control.Monad.Class.MonadSTM hiding (traceTMVar, traceTMVarIO)
type LazyTMVar m = Lazy.TMVar m

-- | 'TMVar' that keeps its value in WHNF at all times
--
-- Does not support an invariant: if the invariant would not be satisfied,
-- we would not be able to put a value into an empty TMVar, which would lead
-- to very hard to debug bugs where code is blocked indefinitely.
newtype StrictTMVar m a = StrictTMVar { toLazyTMVar :: LazyTMVar m a }

fromLazyTMVar :: LazyTMVar m a -> StrictTMVar m a
Expand Down
104 changes: 13 additions & 91 deletions strict-stm/src/Control/Concurrent/Class/MonadSTM/Strict/TVar.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | This module corresponds to `Control.Concurrent.STM.TVar` in "stm" package
--
Expand All @@ -16,17 +13,13 @@ module Control.Concurrent.Class.MonadSTM.Strict.TVar
, castStrictTVar
, newTVar
, newTVarIO
, newTVarWithInvariant
, newTVarWithInvariantIO
, readTVar
, readTVarIO
, writeTVar
, modifyTVar
, stateTVar
, swapTVar
, check
-- ** Low-level API
, checkInvariant
-- * MonadLabelSTM
, labelTVar
, labelTVarIO
Expand All @@ -38,22 +31,11 @@ module Control.Concurrent.Class.MonadSTM.Strict.TVar
import qualified Control.Concurrent.Class.MonadSTM.TVar as Lazy
import Control.Monad.Class.MonadSTM hiding (traceTVar, traceTVarIO)

import GHC.Stack
type LazyTVar m = Lazy.TVar m


type LazyTVar m = Lazy.TVar m

#if CHECK_TVAR_INVARIANT
data StrictTVar m a = StrictTVar
{ invariant :: !(a -> Maybe String)
-- ^ Invariant checked whenever updating the 'StrictTVar'.
, tvar :: !(LazyTVar m a)
}
#else
newtype StrictTVar m a = StrictTVar
{ tvar :: LazyTVar m a
}
#endif
newtype StrictTVar m a = StrictTVar {
tvar :: LazyTVar m a
}

labelTVar :: MonadLabelledSTM m => StrictTVar m a -> String -> STM m ()
labelTVar StrictTVar { tvar } = Lazy.labelTVar tvar
Expand All @@ -76,8 +58,7 @@ traceTVarIO StrictTVar {tvar} = Lazy.traceTVarIO tvar

castStrictTVar :: LazyTVar m ~ LazyTVar n
=> StrictTVar m a -> StrictTVar n a
castStrictTVar v@StrictTVar {tvar} =
mkStrictTVar (getInvariant v) tvar
castStrictTVar StrictTVar {tvar} = StrictTVar {tvar}

-- | Get the underlying @TVar@
--
Expand All @@ -87,50 +68,22 @@ toLazyTVar :: StrictTVar m a -> LazyTVar m a
toLazyTVar StrictTVar { tvar } = tvar

fromLazyTVar :: LazyTVar m a -> StrictTVar m a
fromLazyTVar tvar =
#if CHECK_TVAR_INVARIANT
StrictTVar { invariant = const Nothing
, tvar
}
#else
StrictTVar { tvar }
#endif
fromLazyTVar = StrictTVar

newTVar :: MonadSTM m => a -> STM m (StrictTVar m a)
newTVar !a = (\tvar -> mkStrictTVar (const Nothing) tvar)
<$> Lazy.newTVar a
newTVar !a = StrictTVar <$> Lazy.newTVar a

newTVarIO :: MonadSTM m => a -> m (StrictTVar m a)
newTVarIO = newTVarWithInvariantIO (const Nothing)

newTVarWithInvariant :: (MonadSTM m, HasCallStack)
=> (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
-> a
-> STM m (StrictTVar m a)
newTVarWithInvariant invariant !a =
checkInvariant (invariant a) $
(\tvar -> mkStrictTVar invariant tvar)
<$> Lazy.newTVar a

newTVarWithInvariantIO :: (MonadSTM m, HasCallStack)
=> (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
-> a
-> m (StrictTVar m a)
newTVarWithInvariantIO invariant !a =
checkInvariant (invariant a) $
(\tvar -> mkStrictTVar invariant tvar)
<$> Lazy.newTVarIO a
newTVarIO !a = StrictTVar <$> Lazy.newTVarIO a

readTVar :: MonadSTM m => StrictTVar m a -> STM m a
readTVar StrictTVar { tvar } = Lazy.readTVar tvar

readTVarIO :: MonadSTM m => StrictTVar m a -> m a
readTVarIO StrictTVar { tvar } = Lazy.readTVarIO tvar

writeTVar :: (MonadSTM m, HasCallStack) => StrictTVar m a -> a -> STM m ()
writeTVar v !a =
checkInvariant (getInvariant v a) $
Lazy.writeTVar (tvar v) a
writeTVar :: MonadSTM m => StrictTVar m a -> a -> STM m ()
writeTVar v !a = Lazy.writeTVar (tvar v) a

modifyTVar :: MonadSTM m => StrictTVar m a -> (a -> a) -> STM m ()
modifyTVar v f = readTVar v >>= writeTVar v . f
Expand All @@ -147,34 +100,3 @@ swapTVar v a' = do
a <- readTVar v
writeTVar v a'
return a


{-------------------------------------------------------------------------------
Dealing with invariants
-------------------------------------------------------------------------------}

getInvariant :: StrictTVar m a -> a -> Maybe String
mkStrictTVar :: (a -> Maybe String) -> Lazy.TVar m a -> StrictTVar m a

-- | Check invariant (if enabled) before continuing
--
-- @checkInvariant mErr x@ is equal to @x@ if @mErr == Nothing@, and throws
-- an error @err@ if @mErr == Just err@.
--
-- This is exported so that other code that wants to conditionally check
-- invariants can reuse the same logic, rather than having to introduce new
-- per-package flags.
checkInvariant :: HasCallStack => Maybe String -> a -> a

#if CHECK_TVAR_INVARIANT
getInvariant StrictTVar {invariant} = invariant
mkStrictTVar invariant tvar = StrictTVar {invariant, tvar}

checkInvariant Nothing k = k
checkInvariant (Just err) _ = error $ "Invariant violation: " ++ err
#else
getInvariant _ = \_ -> Nothing
mkStrictTVar _invariant tvar = StrictTVar {tvar}

checkInvariant _err k = k
#endif
8 changes: 0 additions & 8 deletions strict-stm/strict-stm.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ source-repository head
location: https://github.com/input-output-hk/io-sim
subdir: strict-stm

flag checktvarinvariant
Description: Enable runtime invariant checks on StrictT(M)Var
Manual: True
Default: False

flag asserts
description: Enable assertions
manual: False
Expand Down Expand Up @@ -68,6 +63,3 @@ library

if flag(asserts)
ghc-options: -fno-ignore-asserts

if flag(checktvarinvariant)
cpp-options: -DCHECK_TVAR_INVARIANT

0 comments on commit 8293db5

Please sign in to comment.