{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.SymBool
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.SymBool (SymBool (SymBool), SymBoolKey) where

import Control.DeepSeq (NFData)
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import GHC.Generics (Generic)
import Grisette.Internal.Core.Data.Class.AsKey
  ( AsKey,
    KeyEq (keyEq),
    KeyHashable (keyHashWithSalt),
    shouldUseAsKeyHasSymbolicVersionError,
  )
import Grisette.Internal.Core.Data.Class.Function (Apply (FunType, apply))
import Grisette.Internal.Core.Data.Class.Solvable
  ( Solvable (con, conView, ssym, sym),
  )
import Grisette.Internal.Internal.Decl.SymPrim.AllSyms
  ( AllSyms (allSymsS),
    SomeSym (SomeSym),
  )
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( ConRep (ConType),
    LinkedRep (underlyingTerm, wrapTerm),
    SymRep (SymType),
    Term,
    conTerm,
    pformatTerm,
    symTerm,
    typedConstantSymbol,
    pattern ConTerm,
  )
import Language.Haskell.TH.Syntax (Lift)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | Symbolic Boolean type.
--
-- >>> "a" :: SymBool
-- a
-- >>> "a" .&& "b" :: SymBool
-- (&& a b)
--
-- More operations are available. Please refer to "Grisette.Core#g:symops" for
-- more information.
newtype SymBool = SymBool
  { -- | The boolean term wrapped by this symbolic value.
    underlyingBoolTerm :: Term Bool
  }
  -- 'Lift' and 'Generic' use the stock deriving mechanisms (DeriveLift,
  -- DeriveGeneric); 'NFData' is reused from the wrapped term through
  -- GeneralizedNewtypeDeriving.
  deriving (Lift, NFData, Generic)

-- | t'SymBool' type with identity equality, obtained by wrapping it in
-- 'AsKey'.
--
-- Use this type instead of t'SymBool' itself (whose 'Eq' instance crashes)
-- when symbolic Booleans must act as keys in containers based on term
-- equality, such as memo tables.
type SymBoolKey = AsKey SymBool

-- | The concrete counterpart of t'SymBool' is 'Bool'.
instance ConRep SymBool where
  type ConType SymBool = Bool

-- | The symbolic counterpart of 'Bool' is t'SymBool'.
instance SymRep Bool where
  type SymType Bool = SymBool

-- | t'SymBool' is a newtype over @'Term' 'Bool'@, so moving between the two
-- representations only adds or removes the constructor.
instance LinkedRep Bool SymBool where
  underlyingTerm (SymBool term) = term
  wrapTerm = SymBool

-- | A t'SymBool' takes no arguments: its 'FunType' is t'SymBool' itself, so
-- 'apply' is the identity.
instance Apply SymBool where
  type FunType SymBool = SymBool
  apply = id

-- | This will crash the program.
--
-- 'SymBool' cannot be compared concretely. Calling '==' raises an error
-- through 'shouldUseAsKeyHasSymbolicVersionError'. The default '/=' is
-- defined in terms of '==', so it crashes as well.
--
-- If you want to use the type as keys in hash maps based on term equality, say
-- memo table, you should use @'AsKey' 'SymBool'@ instead.
--
-- If you want symbolic version of the equality operator, use
-- t'Grisette.Core.SymEq' instead.
instance Eq SymBool where
  (==) = shouldUseAsKeyHasSymbolicVersionError "SymBool" "(==)" "(.==)"

-- | Key (identity) equality: two t'SymBool's are equal as keys exactly when
-- their wrapped terms are equal under the 'Eq' instance of 'Term'. This is a
-- term-level comparison, not the symbolic @.==@.
instance KeyEq SymBool where
  keyEq (SymBool lhs) (SymBool rhs) = lhs == rhs

-- | Key hashing delegates to the 'Hashable' instance of the wrapped term.
-- This keeps it consistent with 'keyEq', which also compares the wrapped
-- terms.
instance KeyHashable SymBool where
  keyHashWithSalt salt (SymBool term) = hashWithSalt salt term

-- | Builds t'SymBool' values from concrete 'Bool's or symbols, and recovers
-- the concrete 'Bool' when the wrapped term is a constant.
instance Solvable Bool SymBool where
  -- Lift a concrete Boolean into a constant term.
  con = SymBool . conTerm

  -- Wrap a symbol as a typed constant symbol, then as a symbolic term.
  sym = SymBool . symTerm . typedConstantSymbol

  -- Only a 'ConTerm' has a concrete view; every other term yields 'Nothing'.
  conView (SymBool (ConTerm b)) = Just b
  conView _ = Nothing

-- | A string literal denotes a simple symbolic Boolean named by that string.
-- The inner 'fromString' converts the 'String' to the identifier type that
-- 'ssym' expects.
--
-- >>> "a" :: SymBool
-- a
instance IsString SymBool where
  fromString = ssym . fromString

-- | Renders the wrapped term with 'pformatTerm'.
instance Show SymBool where
  show (SymBool term) = pformatTerm term

-- | Contributes the whole value, wrapped as a single 'SomeSym', to the front
-- of the accumulated list.
instance AllSyms SymBool where
  allSymsS v = (SomeSym v :)

-- | Serialized exactly as the wrapped @'Term' 'Bool'@.
instance Serial SymBool where
  serialize = serialize . underlyingBoolTerm
  deserialize = SymBool <$> deserialize

-- | cereal support, delegating to the 'Serial' instance.
instance Cereal.Serialize SymBool where
  put = serialize
  get = deserialize

-- | binary support, delegating to the 'Serial' instance.
instance Binary.Binary SymBool where
  put = serialize
  get = deserialize