{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Internal.Impl.Core.Data.Class.EvalSym
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Impl.Core.Data.Class.EvalSym () where

import Control.Monad.Except (ExceptT)
import Control.Monad.Identity
  ( Identity,
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe (MaybeT)
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Compose (Compose (Compose))
import Data.Functor.Const (Const)
import Data.Functor.Product (Product)
import Data.Functor.Sum (Sum)
import qualified Data.HashSet as HS
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Monoid (Alt, Ap)
import Data.Ord (Down)
import Data.Proxy (Proxy)
import Data.Ratio (Ratio, denominator, numerator, (%))
import qualified Data.Text as T
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Generics.Deriving
  ( Default (Default),
    Default1 (Default1),
    K1 (K1),
    M1 (M1),
    Par1 (Par1),
    Rec1 (Rec1),
    U1,
    V1,
    (:.:) (Comp1),
    type (:*:),
    type (:+:),
  )
import Generics.Deriving.Instances ()
import qualified Generics.Deriving.Monoid as Monoid
import Grisette.Internal.Core.Control.Exception
  ( AssertionError,
    VerificationConditions,
  )
import Grisette.Internal.Internal.Decl.Core.Data.Class.EvalSym
  ( EvalSym (evalSym),
    EvalSym1 (liftEvalSym),
    EvalSym2,
    evalSym1,
  )
import Grisette.Internal.SymPrim.AlgReal (AlgReal)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP
  ( FP,
    FPRoundingMode,
    NotRepresentableFPError,
    ValidFP,
  )
import Grisette.Internal.SymPrim.GeneralFun (type (-->) (GeneralFun))
import Grisette.Internal.SymPrim.Prim.Model (evalTerm)
import Grisette.Internal.SymPrim.Prim.Term
  ( SymRep (SymType),
    someTypedSymbol,
  )
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal))
import Grisette.Internal.SymPrim.SymBV
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool))
import Grisette.Internal.SymPrim.SymFP
  ( SymFP (SymFP),
    SymFPRoundingMode (SymFPRoundingMode),
  )
import Grisette.Internal.SymPrim.SymGeneralFun (type (-~>) (SymGeneralFun))
import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger))
import Grisette.Internal.SymPrim.SymTabularFun (type (=~>) (SymTabularFun))
import Grisette.Internal.SymPrim.TabularFun (type (=->) (TabularFun))
import Grisette.Internal.TH.Derivation.Derive (derive)

-- Instance template for concrete (non-symbolic) types: evaluation ignores
-- both the fill-default flag and the model and returns the value unchanged.
-- The backslash continuations make CPP join the macro body into a single
-- line on expansion, so keep the body to one method equation.
#define CONCRETE_EVALUATESYM(type) \
instance EvalSym type where \
  evalSym _ _ = id

-- Width-indexed variant of the template above for type constructors applied
-- to a natural-number width n; the instance context requires KnownNat n and
-- 1 <= n.
#define CONCRETE_EVALUATESYM_BV(type) \
instance (KnownNat n, 1 <= n) => EvalSym (type n) where \
  evalSym _ _ = id

#if 1
-- Booleans, characters and orderings
CONCRETE_EVALUATESYM(Bool)
CONCRETE_EVALUATESYM(Char)
CONCRETE_EVALUATESYM(Ordering)
-- Fixed-size and arbitrary-precision integral types
CONCRETE_EVALUATESYM(Integer)
CONCRETE_EVALUATESYM(Int)
CONCRETE_EVALUATESYM(Int8)
CONCRETE_EVALUATESYM(Int16)
CONCRETE_EVALUATESYM(Int32)
CONCRETE_EVALUATESYM(Int64)
CONCRETE_EVALUATESYM(Word)
CONCRETE_EVALUATESYM(Word8)
CONCRETE_EVALUATESYM(Word16)
CONCRETE_EVALUATESYM(Word32)
CONCRETE_EVALUATESYM(Word64)
-- Floating-point and real-number types
CONCRETE_EVALUATESYM(Float)
CONCRETE_EVALUATESYM(Double)
CONCRETE_EVALUATESYM(AlgReal)
-- Floating-point rounding modes
CONCRETE_EVALUATESYM(FPRoundingMode)
-- Byte strings and text
CONCRETE_EVALUATESYM(B.ByteString)
CONCRETE_EVALUATESYM(T.Text)
-- Boolean monoid wrappers
CONCRETE_EVALUATESYM(Monoid.All)
CONCRETE_EVALUATESYM(Monoid.Any)
-- Width-indexed concrete bit vectors
CONCRETE_EVALUATESYM_BV(IntN)
CONCRETE_EVALUATESYM_BV(WordN)
#endif

-- | A @Proxy@ holds no value, so evaluation under any model is the identity.
instance EvalSym (Proxy a) where
  evalSym _ _ = id
  {-# INLINE evalSym #-}

-- | Lifted form for @Proxy@: the element evaluator is never called because a
-- @Proxy@ contains no elements.
instance EvalSym1 Proxy where
  liftEvalSym _ _ _ = id
  {-# INLINE liftEvalSym #-}

-- | Evaluates the numerator and the denominator independently, then rebuilds
-- the ratio with @(%)@, which reduces the result to lowest terms.
--
-- NOTE(review): @(%)@ throws when the denominator is zero. This instance
-- therefore relies on @evalSym@ for @a@ never turning a nonzero denominator
-- into zero. For the concrete integral types in this module, @evalSym@ is the
-- identity, so the condition holds for them.
instance (Integral a, EvalSym a) => EvalSym (Ratio a) where
  evalSym fillDefault model r =
    evalSym fillDefault model (numerator r)
      % evalSym fillDefault model (denominator r)

-- | Concrete floating-point values contain no symbolic terms, so evaluation
-- returns the value unchanged.
instance (ValidFP eb fb) => EvalSym (FP eb fb) where
  evalSym _ _ = id
  {-# INLINE evalSym #-}

-- Symbolic primitives
--
-- Each template below builds an instance that unwraps the single term
-- stored in a symbolic value, evaluates it with evalTerm under the model,
-- and rewraps the result. The set argument of evalTerm is HS.empty in all
-- three templates. Presumably this means no symbols are held back from
-- substitution; NOTE(review): confirm against evalTerm.
--
-- The parameter symtype is used both as the type name and as the data
-- constructor in the pattern, so it only fits types whose constructor has
-- the same name as the type.
#define EVALUATE_SYM_SIMPLE(symtype) \
instance EvalSym symtype where \
  evalSym fillDefault model (symtype t) = \
    symtype $ evalTerm fillDefault model HS.empty t

-- Width-indexed variant: adds KnownNat n and 1 <= n to the instance context.
#define EVALUATE_SYM_BV(symtype) \
instance (KnownNat n, 1 <= n) => EvalSym (symtype n) where \
  evalSym fillDefault model (symtype t) = \
    symtype $ evalTerm fillDefault model HS.empty t

-- Variant for symbolic function types: the instance is declared for the
-- type operator op applied to sa and sb, and matched with constructor cons.
-- NOTE(review): the first parameter cop is not referenced in the expansion.
#define EVALUATE_SYM_FUN(cop, op, cons) \
instance EvalSym (op sa sb) where \
  evalSym fillDefault model (cons t) = \
    cons $ evalTerm fillDefault model HS.empty t

#if 1
-- Unindexed symbolic primitives
EVALUATE_SYM_SIMPLE(SymBool)
EVALUATE_SYM_SIMPLE(SymInteger)
EVALUATE_SYM_SIMPLE(SymAlgReal)
EVALUATE_SYM_SIMPLE(SymFPRoundingMode)
-- Width-indexed symbolic bit vectors
EVALUATE_SYM_BV(SymWordN)
EVALUATE_SYM_BV(SymIntN)
-- Symbolic function types, tabular and general
EVALUATE_SYM_FUN((=->), (=~>), SymTabularFun)
EVALUATE_SYM_FUN((-->), (-~>), SymGeneralFun)
#endif

-- | Evaluates the wrapped floating-point term under the model and rewraps it.
-- @HS.empty@ is passed as the symbol-set argument of @evalTerm@.
-- NOTE(review): presumably this means no symbols are held back from
-- substitution; confirm against @evalTerm@.
instance (ValidFP eb sb) => EvalSym (SymFP eb sb) where
  evalSym fillDefault model (SymFP t) =
    SymFP $ evalTerm fillDefault model HS.empty t

-- EvalSym instances, generated by the project TH helper derive, for the unit
-- type and the exception types imported above. Only EvalSym is requested.
derive
  [ ''(),
    ''AssertionError,
    ''VerificationConditions,
    ''NotRepresentableFPError
  ]
  [''EvalSym]

-- Either and the tuple types of arity 2 through 15 get EvalSym, EvalSym1 and
-- EvalSym2 instances from the project TH helper derive.
derive
  [ ''Either,
    ''(,),
    ''(,,),
    ''(,,,),
    ''(,,,,),
    ''(,,,,,),
    ''(,,,,,,),
    ''(,,,,,,,),
    ''(,,,,,,,,),
    ''(,,,,,,,,,),
    ''(,,,,,,,,,,),
    ''(,,,,,,,,,,,),
    ''(,,,,,,,,,,,,),
    ''(,,,,,,,,,,,,,),
    ''(,,,,,,,,,,,,,,)
  ]
  [''EvalSym, ''EvalSym1, ''EvalSym2]

-- Containers, monoid wrappers and monad transformers: each gets EvalSym and
-- EvalSym1 instances from the project TH helper derive. No EvalSym2 is
-- requested for these.
derive
  [ ''[],
    ''Maybe,
    ''Identity,
    ''Monoid.Dual,
    ''Monoid.First,
    ''Monoid.Last,
    ''Monoid.Sum,
    ''Monoid.Product,
    ''Down,
    ''ExceptT,
    ''MaybeT,
    ''WriterLazy.WriterT,
    ''WriterStrict.WriterT
  ]
  [''EvalSym, ''EvalSym1]

-- IdentityT

-- | The plain instance defers to the lifted one through @evalSym1@.
instance (EvalSym1 m, EvalSym a) => EvalSym (IdentityT m a) where
  evalSym = evalSym1
  {-# INLINE evalSym #-}

-- | Unwraps the transformer, evaluates the inner @m a@ with the lifted
-- evaluator of @m@ (passing the element evaluator through), and rewraps it.
instance (EvalSym1 m) => EvalSym1 (IdentityT m) where
  liftEvalSym evalElem fillDefault model (IdentityT inner) =
    IdentityT $ liftEvalSym evalElem fillDefault model inner
  {-# INLINE liftEvalSym #-}

-- Product
--
-- Both the plain and the lifted instance are derived generically through the
-- Default and Default1 newtypes.
deriving via (Default (Product l r a))
  instance (EvalSym (r a), EvalSym (l a)) => EvalSym (Product l r a)

deriving via (Default1 (Product l r))
  instance (EvalSym1 r, EvalSym1 l) => EvalSym1 (Product l r)

-- Sum
--
-- Both the plain and the lifted instance are derived generically through the
-- Default and Default1 newtypes.
deriving via (Default (Sum l r a))
  instance (EvalSym (r a), EvalSym (l a)) => EvalSym (Sum l r a)

deriving via (Default1 (Sum l r))
  instance (EvalSym1 r, EvalSym1 l) => EvalSym1 (Sum l r)

-- Compose
--
-- The plain instance is derived generically and needs EvalSym only for the
-- fully applied inner value of type f (g a).
deriving via (Default (Compose f g a))
  instance (EvalSym (f (g a))) => EvalSym (Compose f g a)

-- | Lifts the element evaluator through @g@ first and then through @f@, and
-- applies the result to the wrapped @f (g a)@.
instance (EvalSym1 f, EvalSym1 g) => EvalSym1 (Compose f g) where
  liftEvalSym evalElem fillDefault model (Compose fga) =
    Compose $ liftEvalSym (liftEvalSym evalElem) fillDefault model fga
  {-# INLINE liftEvalSym #-}

-- Const
--
-- A Const value stores only a value of its first type argument, so both
-- instances require EvalSym a and nothing about the phantom argument.
deriving via (Default (Const a b))
  instance (EvalSym a) => EvalSym (Const a b)

deriving via (Default1 (Const a))
  instance (EvalSym a) => EvalSym1 (Const a)

-- Alt
--
-- Alt f a wraps an f a, so both instances reduce to the ones for f.
deriving via (Default (Alt f a))
  instance (EvalSym (f a)) => EvalSym (Alt f a)

deriving via (Default1 (Alt f))
  instance (EvalSym1 f) => EvalSym1 (Alt f)

-- Ap
--
-- Ap f a wraps an f a, so both instances reduce to the ones for f.
deriving via (Default (Ap f a))
  instance (EvalSym (f a)) => EvalSym (Ap f a)

deriving via (Default1 (Ap f))
  instance (EvalSym1 f) => EvalSym1 (Ap f)

-- Generic
--
-- Instances for the GHC.Generics representation types. Each one is obtained
-- through the Default newtype, and its context asks only for EvalSym on the
-- fields or sub-representations the type wraps.
deriving via (Default (U1 p)) instance EvalSym (U1 p)

deriving via (Default (V1 p)) instance EvalSym (V1 p)

deriving via (Default (Par1 p))
  instance (EvalSym p) => EvalSym (Par1 p)

deriving via (Default (K1 i c p))
  instance (EvalSym c) => EvalSym (K1 i c p)

deriving via (Default (Rec1 f p))
  instance (EvalSym (f p)) => EvalSym (Rec1 f p)

deriving via (Default (M1 i c f p))
  instance (EvalSym (f p)) => EvalSym (M1 i c f p)

deriving via (Default ((f :.: g) p))
  instance (EvalSym (f (g p))) => EvalSym ((f :.: g) p)

deriving via (Default ((f :+: g) p))
  instance (EvalSym (g p), EvalSym (f p)) => EvalSym ((f :+: g) p)

deriving via (Default ((f :*: g) p))
  instance (EvalSym (g p), EvalSym (f p)) => EvalSym ((f :*: g) p)

-- | Evaluates both fields of a tabular function: the list of
-- @(argument, result)@ pairs and the trailing value of type @b@
-- (presumably the result for arguments not in the list; NOTE(review):
-- confirm against the TabularFun definition).
instance (EvalSym a, EvalSym b) => EvalSym (a =-> b) where
  evalSym fillDefault model (TabularFun entries fallback) =
    TabularFun
      (evalSym fillDefault model entries)
      (evalSym fillDefault model fallback)

-- | Evaluates the body term of a general function under the model and keeps
-- its argument symbol. The argument symbol is passed to @evalTerm@ in a
-- singleton set. Presumably this stops the model from substituting the bound
-- argument inside the body; NOTE(review): confirm against @evalTerm@.
--
-- NOTE(review): the @EvalSym (SymType b)@ constraint is not used by this
-- body. It is kept so the instance context stays unchanged.
instance (EvalSym (SymType b)) => EvalSym (a --> b) where
  evalSym fillDefault model (GeneralFun arg body) =
    GeneralFun arg $
      evalTerm fillDefault model (HS.singleton $ someTypedSymbol arg) body