{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.SymBool
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.SymBool (SymBool (SymBool)) where

import Control.DeepSeq (NFData)
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import GHC.Generics (Generic)
import Grisette.Internal.Core.Data.Class.Function (Apply (FunType, apply))
import Grisette.Internal.Core.Data.Class.Solvable
  ( Solvable (con, conView, ssym, sym),
  )
import Grisette.Internal.Internal.Decl.SymPrim.AllSyms
  ( AllSyms (allSymsS),
    SomeSym (SomeSym),
  )
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( ConRep (ConType),
    LinkedRep (underlyingTerm, wrapTerm),
    SymRep (SymType),
    Term,
    conTerm,
    pformatTerm,
    symTerm,
    typedConstantSymbol,
    pattern ConTerm,
  )
import Language.Haskell.TH.Syntax (Lift)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | Symbolic Boolean type.
--
-- >>> "a" :: SymBool
-- a
-- >>> "a" .&& "b" :: SymBool
-- (&& a b)
--
-- More operations are available. Please refer to "Grisette.Core#g:symops" for
-- more information.
newtype SymBool = SymBool
  { -- | The wrapped 'Term' of type 'Bool' carried by this symbolic Boolean.
    underlyingBoolTerm :: Term Bool
  }
  deriving (Lift, NFData, Generic)

-- | Maps the symbolic type 'SymBool' to 'Bool' through the 'ConType'
-- associated type family.
instance ConRep SymBool where
  type ConType SymBool = Bool

-- | Maps 'Bool' to the symbolic type 'SymBool' through the 'SymType'
-- associated type family (the reverse direction of the 'ConRep' mapping).
instance SymRep Bool where
  type SymType Bool = SymBool

-- | Links 'Bool' and 'SymBool' by converting between a 'SymBool' and the
-- 'Term' it wraps.
instance LinkedRep Bool SymBool where
  -- Unwrap the newtype to expose the term.
  underlyingTerm (SymBool a) = a

  -- Wrapping is just the newtype constructor.
  wrapTerm = SymBool

-- | 'SymBool' is not a function type, so applying it yields the value itself:
-- 'FunType' is 'SymBool' and 'apply' is the identity.
instance Apply SymBool where
  type FunType SymBool = SymBool
  apply = id

-- | Equality is delegated to the 'Eq' instance of the wrapped terms.
--
-- NOTE(review): this yields a concrete 'Bool' by comparing the terms
-- themselves, not a symbolic equality formula; what the comparison inspects
-- is determined by the 'Eq' instance of 'Term' -- confirm there.
instance Eq SymBool where
  SymBool l == SymBool r = l == r

-- | Hashing is delegated to the 'Hashable' instance of the wrapped term.
instance Hashable SymBool where
  hashWithSalt s (SymBool v) = s `hashWithSalt` v

-- | Construction and inspection of 'SymBool' values from concrete 'Bool's and
-- from symbols.
instance Solvable Bool SymBool where
  -- Wrap a concrete Boolean as a constant term.
  con = SymBool . conTerm

  -- Turn a symbol into a typed constant symbol, then into a symbolic term.
  sym = SymBool . symTerm . typedConstantSymbol

  -- Recover the concrete value only when the wrapped term matches 'ConTerm';
  -- every other term yields 'Nothing'.
  conView (SymBool (ConTerm t)) = Just t
  conView _ = Nothing

-- | A string literal denotes a symbolic Boolean: the string is first
-- converted to an identifier via that type's own 'fromString', and the
-- identifier is then passed to 'ssym'.
instance IsString SymBool where
  fromString = ssym . fromString

-- | Shows a 'SymBool' by formatting the wrapped term with 'pformatTerm'.
instance Show SymBool where
  show (SymBool t) = pformatTerm t

-- | A 'SymBool' contributes exactly one entry: itself, wrapped in 'SomeSym'.
instance AllSyms SymBool where
  -- Difference-list style: prepend this value to whatever list is supplied.
  allSymsS v = (SomeSym v :)

-- | Serialization is delegated to the 'Serial' instance of the wrapped term;
-- the newtype wrapper itself adds nothing to the encoding.
instance Serial SymBool where
  -- Unwrap, then serialize the term.
  serialize = serialize . underlyingBoolTerm

  -- Deserialize a term, then wrap it.
  deserialize = SymBool <$> deserialize

-- | The cereal instance reuses the 'Serial' methods, so both share one
-- encoding.
instance Cereal.Serialize SymBool where
  put = serialize
  get = deserialize

-- | The binary instance reuses the 'Serial' methods, so both share one
-- encoding.
instance Binary.Binary SymBool where
  put = serialize
  get = deserialize