{-# LANGUAGE LambdaCase #-}
module Grisette.Experimental.MonadParallelUnion
( MonadParallelUnion (..),
)
where
import Control.DeepSeq (NFData, force)
import Control.Monad.Except (ExceptT (ExceptT), runExceptT)
import Control.Monad.Identity (IdentityT (IdentityT, runIdentityT))
import qualified Control.Monad.RWS.Lazy as RWSLazy
import qualified Control.Monad.RWS.Strict as RWSStrict
import Control.Monad.Reader (ReaderT (ReaderT, runReaderT))
import qualified Control.Monad.State.Lazy as StateLazy
import qualified Control.Monad.State.Strict as StateStrict
import Control.Monad.Trans.Maybe (MaybeT (MaybeT, runMaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import Control.Parallel.Strategies (rpar, rseq, runEval)
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Control.Monad.Union (Union, unionBase)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.Core.Data.UnionBase (UnionBase (UnionIf, UnionSingle))
-- | Union-based monads equipped with a bind whose continuation results may be
-- computed in parallel.
--
-- For 'Union', 'parBindUnion' applies the continuation under both branches of
-- every 'UnionIf' node. It sparks the two sub-results with 'rpar', forces them
-- to normal form with 'force', and recombines them with 'mrgIf'. The monad
-- transformer instances thread their own effect and delegate to the underlying
-- monad's 'parBindUnion'.
class (MonadUnion m, TryMerge m) => MonadParallelUnion m where
  -- | Bind-style composition: runs the first computation and feeds its
  -- result(s) to the continuation.
  --
  -- 'Mergeable' on the result type lets branch results be merged, and
  -- 'NFData' lets them be fully evaluated before they are combined.
  parBindUnion ::
    (Mergeable b, NFData b) =>
    m a ->
    (a -> m b) ->
    m b
-- | Lifts the underlying monad's parallel bind through 'MaybeT'.
--
-- A 'Nothing' produced by the first computation is returned as-is without
-- running the continuation. A 'Just' value is handed to the continuation.
instance (MonadParallelUnion m) => MonadParallelUnion (MaybeT m) where
  parBindUnion (MaybeT inner) k =
    MaybeT (inner `parBindUnion` maybe (return Nothing) (runMaybeT . k))
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through 'ExceptT'.
--
-- A 'Left' error from the first computation is re-wrapped and returned
-- without running the continuation. A 'Right' value is handed to the
-- continuation. The error type needs 'Mergeable' and 'NFData' because the
-- inner bind produces @Either e b@ results.
instance
  (MonadParallelUnion m, Mergeable e, NFData e) =>
  MonadParallelUnion (ExceptT e m)
  where
  parBindUnion (ExceptT inner) k =
    ExceptT (inner `parBindUnion` either (return . Left) (runExceptT . k))
  {-# INLINE parBindUnion #-}
-- | Threads the state of a lazy 'StateLazy.StateT' through the underlying
-- monad's parallel bind.
--
-- The intermediate @(value, state)@ pair is matched with an irrefutable
-- (@~@) pattern, so matching it does not force the pair. The state type needs
-- 'Mergeable' and 'NFData' because the inner bind produces @(b, s)@ results.
instance
  (MonadParallelUnion m, Mergeable s, NFData s) =>
  MonadParallelUnion (StateLazy.StateT s m)
  where
  parBindUnion (StateLazy.StateT run) k = StateLazy.StateT step
    where
      -- Run the first computation from the incoming state, then continue
      -- from the state it produced.
      step s0 = run s0 `parBindUnion` \ ~(a, s1) -> StateLazy.runStateT (k a) s1
  {-# INLINE parBindUnion #-}
-- | Threads the state of a strict 'StateStrict.StateT' through the underlying
-- monad's parallel bind.
--
-- The intermediate @(value, state)@ pair is matched with an ordinary
-- (strict) tuple pattern. The state type needs 'Mergeable' and 'NFData'
-- because the inner bind produces @(b, s)@ results.
instance
  (MonadParallelUnion m, Mergeable s, NFData s) =>
  MonadParallelUnion (StateStrict.StateT s m)
  where
  parBindUnion (StateStrict.StateT run) k = StateStrict.StateT $ \s0 ->
    run s0 `parBindUnion` \(a, s1) -> StateStrict.runStateT (k a) s1
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through a lazy
-- 'WriterLazy.WriterT'.
--
-- Both binds go through 'parBindUnion', and the accumulated output is
-- combined as @first <> second@. The @(value, output)@ pairs are matched
-- with irrefutable (@~@) patterns.
instance
  (MonadParallelUnion m, Mergeable w, Monoid w, NFData w) =>
  MonadParallelUnion (WriterLazy.WriterT w m)
  where
  parBindUnion (WriterLazy.WriterT inner) k =
    WriterLazy.WriterT (inner `parBindUnion` continueWith)
    where
      -- Run the continuation, then append its output after the first one's.
      continueWith ~(a, w1) =
        WriterLazy.runWriterT (k a) `parBindUnion` \ ~(b, w2) ->
          return (b, w1 <> w2)
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through a strict
-- 'WriterStrict.WriterT'.
--
-- Both binds go through 'parBindUnion', and the accumulated output is
-- combined as @first <> second@. The @(value, output)@ pairs are matched
-- with ordinary (strict) tuple patterns.
instance
  (MonadParallelUnion m, Mergeable w, Monoid w, NFData w) =>
  MonadParallelUnion (WriterStrict.WriterT w m)
  where
  parBindUnion (WriterStrict.WriterT inner) k =
    WriterStrict.WriterT $
      inner `parBindUnion` \(a, w1) ->
        WriterStrict.runWriterT (k a) `parBindUnion` \(b, w2) ->
          return (b, w1 <> w2)
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through 'ReaderT'.
--
-- The same environment is passed to both the first computation and the
-- continuation.
--
-- NOTE(review): this method body never merges or forces the environment, so
-- the @Mergeable r@ / @NFData r@ constraints look unnecessary here. They are
-- kept in case a superclass instance for 'ReaderT' requires them; confirm
-- before removing.
instance
  (MonadParallelUnion m, Mergeable r, NFData r) =>
  MonadParallelUnion (ReaderT r m)
  where
  parBindUnion (ReaderT run) k = ReaderT $ \env ->
    run env `parBindUnion` \a -> runReaderT (k a) env
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through 'IdentityT'.
--
-- Each continuation result is unwrapped and passed through 'tryMerge' before
-- being returned to the underlying bind.
instance (MonadParallelUnion m) => MonadParallelUnion (IdentityT m) where
  parBindUnion (IdentityT inner) k =
    IdentityT (inner `parBindUnion` \a -> tryMerge (runIdentityT (k a)))
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through a strict
-- 'RWSStrict.RWST'.
--
-- The environment is shared, the state is threaded from the first
-- computation into the continuation, and the outputs are combined as
-- @first <> second@. The @(value, state, output)@ triples are matched with
-- ordinary (strict) tuple patterns.
--
-- NOTE(review): this body never merges or forces the environment @r@, so the
-- @Mergeable r@ / @NFData r@ constraints look unnecessary here. They are kept
-- in case a superclass instance requires them; confirm before removing.
instance
  ( MonadParallelUnion m,
    Mergeable s,
    Mergeable r,
    Mergeable w,
    Monoid w,
    NFData r,
    NFData w,
    NFData s
  ) =>
  MonadParallelUnion (RWSStrict.RWST r w s m)
  where
  parBindUnion action k = RWSStrict.RWST $ \env s0 ->
    RWSStrict.runRWST action env s0 `parBindUnion` \(a, s1, w1) ->
      RWSStrict.runRWST (k a) env s1 `parBindUnion` \(b, s2, w2) ->
        return (b, s2, w1 <> w2)
  {-# INLINE parBindUnion #-}
-- | Lifts the underlying monad's parallel bind through a lazy
-- 'RWSLazy.RWST'.
--
-- The environment is shared, the state is threaded from the first
-- computation into the continuation, and the outputs are combined as
-- @first <> second@. The @(value, state, output)@ triples are matched with
-- irrefutable (@~@) patterns.
--
-- NOTE(review): this body never merges or forces the environment @r@, so the
-- @Mergeable r@ / @NFData r@ constraints look unnecessary here. They are kept
-- in case a superclass instance requires them; confirm before removing.
instance
  ( MonadParallelUnion m,
    Mergeable s,
    Mergeable r,
    Mergeable w,
    Monoid w,
    NFData r,
    NFData w,
    NFData s
  ) =>
  MonadParallelUnion (RWSLazy.RWST r w s m)
  where
  parBindUnion action k = RWSLazy.RWST $ \env s0 ->
    RWSLazy.runRWST action env s0 `parBindUnion` \ ~(a, s1, w1) ->
      RWSLazy.runRWST (k a) env s1 `parBindUnion` \ ~(b, s2, w2) ->
        return (b, s2, w1 <> w2)
  {-# INLINE parBindUnion #-}
-- | Top-level parallel bind over a 'UnionBase'.
--
-- A lone 'UnionSingle' value is handed to the continuation, and the result
-- is passed through 'tryMerge'. Any other shape is delegated to the
-- recursive worker 'parBindUnion''.
parBindUnion'' :: (Mergeable b, NFData b) => UnionBase a -> (a -> Union b) -> Union b
parBindUnion'' base f = case base of
  UnionSingle a -> tryMerge (f a)
  _ -> parBindUnion' base f
-- | Recursive worker for the parallel bind over a 'UnionBase'.
--
-- At a 'UnionSingle' leaf the continuation is applied directly, with no
-- extra merge. At a 'UnionIf' node:
--
-- * both sub-unions are processed recursively;
-- * each sub-result is fully evaluated with 'force' inside an 'rpar' spark,
--   so the two may be evaluated in parallel;
-- * both sparks are awaited with 'rseq';
-- * the results are recombined under the node's condition with 'mrgIf'.
--
-- The first two fields of 'UnionIf' are not used here.
parBindUnion' :: (Mergeable b, NFData b) => UnionBase a -> (a -> Union b) -> Union b
parBindUnion' base f = case base of
  UnionSingle a -> f a
  UnionIf _ _ cond thenBranch elseBranch -> runEval $ do
    l <- sparkForced thenBranch
    r <- sparkForced elseBranch
    l' <- rseq l
    r' <- rseq r
    rseq (mrgIf cond l' r')
  where
    -- Spark the fully-forced result of binding one sub-union.
    sparkForced branch = rpar (force (parBindUnion' branch f))
-- NOTE(review): this binding is self-recursive, and GHC does not inline loop
-- breakers, so this pragma likely has no effect. Confirm before relying on it.
{-# INLINE parBindUnion' #-}
-- | The base instance. The 'Union' is unwrapped to its 'UnionBase' with
-- 'unionBase' and handed to 'parBindUnion''.
instance MonadParallelUnion Union where
  parBindUnion u = parBindUnion'' (unionBase u)
  {-# INLINE parBindUnion #-}