{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Internal.Internal.Impl.Core.Data.UnionBase
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Impl.Core.Data.UnionBase
  (
  )
where

#if MIN_VERSION_prettyprinter(1,7,0)
import Prettyprinter (align, group, nest, vsep)
#else
import Data.Text.Prettyprint.Doc (align, group, nest, vsep)
#endif

import Control.DeepSeq (NFData (rnf), NFData1 (liftRnf), rnf1)
import qualified Data.Binary as Binary
import Data.Bytes.Get (MonadGet (getWord8))
import Data.Bytes.Put (MonadPut (putWord8))
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Functor.Classes
  ( Eq1 (liftEq),
    Show1 (liftShowsPrec),
    showsPrec1,
    showsUnaryWith,
  )
import Data.Hashable (Hashable (hashWithSalt))
import Data.Hashable.Lifted (Hashable1 (liftHashWithSalt), hashWithSalt1)
import qualified Data.Serialize as Cereal
import Grisette.Internal.Core.Data.Class.Mergeable
  ( Mergeable (rootStrategy),
  )
import Grisette.Internal.Core.Data.Class.PPrint
  ( PPrint (pformatPrec),
    PPrint1 (liftPFormatPrec),
    condEnclose,
    pformatPrec1,
  )
import Grisette.Internal.Internal.Decl.Core.Data.UnionBase
  ( UnionBase (UnionIf, UnionSingle),
    ifWithStrategy,
  )
import Grisette.Internal.SymPrim.AllSyms
  ( AllSyms (allSymsS),
    AllSyms1 (liftAllSymsS),
    SomeSym (SomeSym),
  )

-- | Structural equality lifted over the element comparison.
--
-- Two 'UnionSingle' nodes are equal when their elements satisfy the supplied
-- predicate. Two 'UnionIf' nodes are equal when all five fields agree:
--
-- * the element field, via the supplied predicate;
-- * the 'Bool' field and the condition, via '==';
-- * both sub-unions, recursively.
--
-- Nodes built with different constructors are never equal.
instance Eq1 UnionBase where
  liftEq eq (UnionSingle x) (UnionSingle y) = eq x y
  liftEq eq (UnionIf xl xi xc xt xf) (UnionIf yl yi yc yt yf) =
    -- 'and' short-circuits left to right, just like a chain of '&&'.
    and
      [ eq xl yl,
        xi == yi,
        xc == yc,
        liftEq eq xt yt,
        liftEq eq xf yf
      ]
  liftEq _ _ _ = False

-- | Deep evaluation. This delegates to the 'NFData1' instance and uses the
-- element type's own 'rnf' to force each stored element.
instance (NFData a) => NFData (UnionBase a) where
  rnf = liftRnf rnf

-- | Deep evaluation lifted over an element forcer.
--
-- 'UnionSingle' forces only its element. 'UnionIf' forces, in order:
--
-- 1. its element field, using the supplied forcer;
-- 2. the 'Bool' field and the condition, using 'rnf';
-- 3. both sub-unions, recursively.
instance NFData1 UnionBase where
  liftRnf forceElem (UnionSingle x) = forceElem x
  liftRnf forceElem (UnionIf x flag cond lhs rhs) =
    -- Sequence every component's forcing from left to right.
    foldr
      seq
      ()
      [ forceElem x,
        rnf flag,
        rnf cond,
        liftRnf forceElem lhs,
        liftRnf forceElem rhs
      ]

-- | Binary encoding of a union tree.
--
-- Wire format:
--
-- * 'UnionSingle': tag byte @0@, then the element.
-- * 'UnionIf': tag byte @1@, then the condition, then the two sub-unions in
--   field order.
--
-- The first two fields of 'UnionIf' are not written. When decoding, an 'If'
-- node is rebuilt with 'ifWithStrategy' using the element type's
-- 'rootStrategy'. NOTE(review): depending on what 'ifWithStrategy' does, the
-- decoded tree may not be structurally identical to the encoded one; confirm
-- this against its definition.
--
-- Any other tag byte fails with @"Invalid tag"@.
instance (Mergeable a, Serial a) => Serial (UnionBase a) where
  serialize (UnionSingle x) = do
    putWord8 0
    serialize x
  serialize (UnionIf _ _ cond lhs rhs) = do
    putWord8 1
    serialize cond
    serialize lhs
    serialize rhs
  deserialize = do
    tag <- getWord8
    case tag of
      0 -> UnionSingle <$> deserialize
      1 -> do
        -- Fields are read back in the order they were written.
        cond <- deserialize
        lhs <- deserialize
        rhs <- deserialize
        pure $ ifWithStrategy rootStrategy cond lhs rhs
      _ -> fail "Invalid tag"

-- | cereal encoding. This reuses the 'Serial' instance, so the format matches
-- 'serialize' and 'deserialize' exactly.
instance (Mergeable a, Serial a) => Cereal.Serialize (UnionBase a) where
  put union = serialize union
  get = deserialize

-- | binary encoding. This reuses the 'Serial' instance, so the format matches
-- 'serialize' and 'deserialize' exactly.
instance (Mergeable a, Serial a) => Binary.Binary (UnionBase a) where
  put union = serialize union
  get = deserialize

-- | 'Show' rendering lifted over an element printer.
--
-- * 'UnionSingle' is shown as @Single x@.
-- * 'UnionIf' is shown as @If cond t f@, where @t@ and @f@ are its two
--   sub-unions. It is wrapped in parentheses when the context precedence
--   exceeds 10.
--
-- The first two fields of 'UnionIf' are not shown.
instance Show1 UnionBase where
  liftShowsPrec showElem _ prec (UnionSingle x) =
    showsUnaryWith showElem "Single" prec x
  liftShowsPrec showElem showElems prec (UnionIf _ _ cond lhs rhs) =
    showParen (prec > 10) $
      showString "If"
        . withSpace (showsPrec 11 cond)
        . withSpace (showBranch lhs)
        . withSpace (showBranch rhs)
    where
      -- Prefix one argument with a single separating space.
      withSpace shower = showChar ' ' . shower
      -- Sub-unions are shown at argument precedence (11).
      showBranch = liftShowsPrec showElem showElems 11

-- | 'Show' via the 'Show1' instance, using the element type's own 'showsPrec'
-- and 'showList'.
instance (Show a) => Show (UnionBase a) where
  showsPrec = liftShowsPrec showsPrec showList

-- | Pretty-printing via the 'PPrint1' instance, using 'pformatPrec1'.
instance (PPrint a) => PPrint (UnionBase a) where
  pformatPrec prec union = pformatPrec1 prec union

-- | Pretty-printing lifted over an element formatter.
--
-- * 'UnionSingle' is rendered as its element alone, with no constructor
--   name.
-- * 'UnionIf' is rendered as a grouped, aligned block. The block contains
--   @If@, the condition, and the two sub-unions, separated with 'vsep' and
--   nested by 2. It is wrapped in parentheses when the context precedence
--   exceeds 10.
--
-- The first two fields of 'UnionIf' are not printed.
instance PPrint1 UnionBase where
  liftPFormatPrec fmtElem _ prec (UnionSingle x) = fmtElem prec x
  liftPFormatPrec fmtElem fmtList prec (UnionIf _ _ cond lhs rhs) =
    group . condEnclose (prec > 10) "(" ")" . align . nest 2 . vsep $
      ["If", pformatPrec 11 cond, fmtBranch lhs, fmtBranch rhs]
    where
      -- Sub-unions are formatted at argument precedence (11).
      fmtBranch = liftPFormatPrec fmtElem fmtList 11

-- | Salted hashing via the 'Hashable1' instance, using 'hashWithSalt1'.
instance (Hashable a) => Hashable (UnionBase a) where
  hashWithSalt salt union = hashWithSalt1 salt union

-- | Salted hashing lifted over an element hasher.
--
-- A constructor tag (@0@ for 'UnionSingle', @1@ for 'UnionIf') is mixed into
-- the incoming salt first. The running salt is then threaded through every
-- hashed component. For 'UnionIf', only the condition and the two sub-unions
-- are hashed; the first two fields are ignored.
--
-- The parentheses in the 'UnionSingle' equation are required. The hashable
-- package declares 'hashWithSalt' as @infixl 0@, while a back-quoted local
-- function such as @f@ defaults to @infixl 9@. Without the parentheses, @f@
-- would receive the constant @0@ as its salt instead of the running salt.
-- Elements that collide under that constant salt would then collide for
-- every salt.
instance Hashable1 UnionBase where
  liftHashWithSalt f s (UnionSingle a) = f (s `hashWithSalt` (0 :: Int)) a
  liftHashWithSalt f s (UnionIf _ _ c l r) =
    let p = liftHashWithSalt f
     in (s `hashWithSalt` (1 :: Int) `hashWithSalt` c) `p` l `p` r

-- | Collects the symbolic terms of a union as a difference list.
--
-- * 'UnionSingle' contributes the symbols of its element.
-- * 'UnionIf' contributes its condition first, then the symbols of both
--   sub-unions in field order.
instance (AllSyms a) => AllSyms (UnionBase a) where
  allSymsS (UnionSingle v) = allSymsS v
  allSymsS (UnionIf _ _ c t f) = (SomeSym c :) . allSymsS t . allSymsS f

-- | Symbol collection lifted over an element collector, as a difference
-- list.
--
-- * 'UnionSingle' delegates to the supplied collector.
-- * 'UnionIf' prepends its condition, then the symbols of both sub-unions in
--   field order.
instance AllSyms1 UnionBase where
  liftAllSymsS collect (UnionSingle v) = collect v
  liftAllSymsS collect (UnionIf _ _ c t f) =
    (SomeSym c :) . recurse t . recurse f
    where
      -- Walk a sub-union with the same element collector.
      recurse = liftAllSymsS collect