{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.SymInteger
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.SymInteger
  ( SymInteger (SymInteger),
    SymIntegerKey,
  )
where

import Control.DeepSeq (NFData)
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import GHC.Generics (Generic)
import Grisette.Internal.Core.Data.Class.AsKey
  ( AsKey,
    KeyEq (keyEq),
    KeyHashable (keyHashWithSalt),
    shouldUseAsKeyHasSymbolicVersionError,
    shouldUseSymbolicVersionError,
  )
import Grisette.Internal.Core.Data.Class.Function (Apply (FunType, apply))
import Grisette.Internal.Core.Data.Class.Solvable
  ( Solvable (con, conView, ssym, sym),
  )
import Grisette.Internal.Internal.Decl.SymPrim.AllSyms
  ( AllSyms (allSymsS),
    SomeSym (SomeSym),
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( ConRep (ConType),
    LinkedRep (underlyingTerm, wrapTerm),
    PEvalNumTerm
      ( pevalAbsNumTerm,
        pevalAddNumTerm,
        pevalMulNumTerm,
        pevalNegNumTerm,
        pevalSignumNumTerm
      ),
    SymRep (SymType),
    Term,
    conTerm,
    pevalDivIntegralTerm,
    pevalITETerm,
    pevalLeOrdTerm,
    pevalModIntegralTerm,
    pevalQuotIntegralTerm,
    pevalRemIntegralTerm,
    pevalSubNumTerm,
    pformatTerm,
    symTerm,
    typedConstantSymbol,
    pattern ConTerm,
  )
import Language.Haskell.TH.Syntax (Lift)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | Symbolic (unbounded, mathematical) integer type.
--
-- >>> "a" + 1 :: SymInteger
-- (+ 1 a)
--
-- More operations are available. Please refer to "Grisette.Core#g:symops" for
-- more information.
newtype SymInteger = SymInteger
  { -- | The underlying 'Term' whose concrete type is 'Integer'.
    underlyingIntegerTerm :: Term Integer
  }
  deriving (Lift, NFData, Generic)

-- | t'SymInteger' type with identity equality: two keys compare equal when
-- their underlying terms are equal, instead of going through the crashing
-- 'Eq' instance of t'SymInteger'. Use it as the key type of hash maps based
-- on term equality, such as memo tables.
type SymIntegerKey = AsKey SymInteger

-- | The concrete counterpart of 'SymInteger' is 'Integer'.
instance ConRep SymInteger where
  type ConType SymInteger = Integer

-- | The symbolic counterpart of 'Integer' is 'SymInteger'.
instance SymRep Integer where
  type SymType Integer = SymInteger

-- | Links the concrete 'Integer' type with its symbolic wrapper:
-- 'wrapTerm' wraps a term, 'underlyingTerm' unwraps it.
instance LinkedRep Integer SymInteger where
  wrapTerm t = SymInteger t
  underlyingTerm = underlyingIntegerTerm

-- | A 'SymInteger' takes no arguments: applying it yields the value itself.
instance Apply SymInteger where
  type FunType SymInteger = SymInteger
  apply v = v

-- | Arithmetic builds symbolic terms through the partial evaluators
-- (@pevalAddNumTerm@ and friends); integer literals become concrete
-- constants via 'con'.
instance Num SymInteger where
  fromInteger = con
  SymInteger l + SymInteger r = SymInteger (pevalAddNumTerm l r)
  SymInteger l - SymInteger r = SymInteger (pevalSubNumTerm l r)
  SymInteger l * SymInteger r = SymInteger (pevalMulNumTerm l r)
  negate (SymInteger v) = SymInteger (pevalNegNumTerm v)
  abs (SymInteger v) = SymInteger (pevalAbsNumTerm v)
  signum (SymInteger v) = SymInteger (pevalSignumNumTerm v)

-- | Infinite arithmetic progression: starts at the first argument and adds
-- the second argument (the step) for each following element. Each element is
-- forced to WHNF with 'seq' before its cons cell is produced.
enumDeltaSymInteger :: SymInteger -> SymInteger -> [SymInteger]
enumDeltaSymInteger start step = go start
  where
    -- The next element is @step + cur@ (operand order kept as-is, since it
    -- determines the shape of the built term).
    go cur = cur `seq` (cur : go (step + cur))
{-# NOINLINE [1] enumDeltaSymInteger #-}

-- | Supported: 'succ', 'pred', 'toEnum', and the unbounded enumerations
-- 'enumFrom' / 'enumFromThen'. 'fromEnum', 'enumFromTo' and 'enumFromThenTo'
-- call 'error', since they would need to inspect the symbolic value.
instance Enum SymInteger where
  succ = (+ 1)
  pred = subtract 1
  toEnum = fromIntegral
  fromEnum = error "fromEnum: fromEnum isn't supported for SymInteger"
  enumFrom start = enumDeltaSymInteger start 1
  {-# INLINE enumFrom #-}
  enumFromThen start next = enumDeltaSymInteger start (next - start)
  {-# INLINE enumFromThen #-}
  enumFromTo =
    error "enumFromTo: enumFromTo isn't supported for SymInteger"
  enumFromThenTo =
    error "enumFromThenTo: enumFromThenTo isn't supported for SymInteger"

-- | Only 'max' and 'min' are supported; they build symbolic if-then-else
-- terms. The remaining methods ('compare', '<', '<=', '>', '>=') crash the
-- program, because 'SymInteger' cannot be compared concretely.
--
-- If you want symbolic version of the comparison operators, use
-- t'Grisette.Core.SymOrd' instead.
instance Ord SymInteger where
  compare = shouldUseSymbolicVersionError "SymInteger" "compare" "symCompare"
  (<) = shouldUseSymbolicVersionError "SymInteger" "(<)" "(.<)"
  (<=) = shouldUseSymbolicVersionError "SymInteger" "(<=)" "(.<=)"
  (>) = shouldUseSymbolicVersionError "SymInteger" "(>)" "(.>)"
  (>=) = shouldUseSymbolicVersionError "SymInteger" "(>=)" "(.>=)"
  max (SymInteger l) (SymInteger r) = SymInteger (pevalITETerm lLeR r l)
    where
      lLeR = pevalLeOrdTerm l r
  min (SymInteger l) (SymInteger r) = SymInteger (pevalITETerm lLeR l r)
    where
      lLeR = pevalLeOrdTerm l r

-- | Not supported: 'toRational' would need a concrete value, so it crashes
-- the program.
instance Real SymInteger where
  toRational = \_ ->
    error "toRational: toRational isn't supported for SymInteger"

-- | 'div', 'mod', 'quot', 'rem', 'divMod' and 'quotRem' build symbolic terms
-- and will not throw errors. The result is considered undefined if the
-- divisor is 0.
--
-- It is the responsibility of the caller to ensure that the divisor is not
-- zero with the symbolic constraints, or use the t'Grisette.Core.DivOr' or
-- t'Grisette.Core.SafeDiv' classes.
--
-- 'toInteger' is not supported: it would need a concrete value, so calling it
-- crashes the program.
instance Integral SymInteger where
  toInteger = error "toInteger: toInteger isn't supported for SymInteger"
  div (SymInteger l) (SymInteger r) = SymInteger $ pevalDivIntegralTerm l r
  mod (SymInteger l) (SymInteger r) = SymInteger $ pevalModIntegralTerm l r
  quot (SymInteger l) (SymInteger r) = SymInteger $ pevalQuotIntegralTerm l r
  rem (SymInteger l) (SymInteger r) = SymInteger $ pevalRemIntegralTerm l r
  -- The paired methods reuse the single-result ones so they always build
  -- exactly the same terms as 'div'/'mod' and 'quot'/'rem'.
  divMod l r = (div l r, mod l r)
  quotRem l r = (quot l r, rem l r)

-- | Concrete equality is not supported: '(==)' crashes the program, because
-- 'SymInteger' cannot be compared concretely.
--
-- For hash maps keyed on term equality (e.g., memo tables), use
-- @'AsKey' 'SymInteger'@ instead.
--
-- For a symbolic equality operator, use t'Grisette.Core.SymEq'.
instance Eq SymInteger where
  (==) =
    shouldUseAsKeyHasSymbolicVersionError
      "SymInteger"
      "(==)"
      "(.==)"

-- | Key equality compares the underlying terms with their 'Eq' instance.
instance KeyEq SymInteger where
  keyEq x y = underlyingIntegerTerm x == underlyingIntegerTerm y

-- | Key hashing hashes the underlying term with its 'Hashable' instance.
instance KeyHashable SymInteger where
  keyHashWithSalt salt = hashWithSalt salt . underlyingIntegerTerm

-- | Build concrete constants ('con') or named symbolic constants ('sym'), and
-- recover the concrete value of a constant term ('conView').
instance Solvable Integer SymInteger where
  con c = SymInteger (conTerm c)
  sym s = SymInteger (symTerm (typedConstantSymbol s))
  conView v = case v of
    SymInteger (ConTerm c) -> Just c
    _ -> Nothing

-- | A string literal denotes the symbolic constant with that identifier, built
-- with 'ssym'.
instance IsString SymInteger where
  fromString str = ssym (fromString str)

-- | Shows the pretty-printed underlying term, e.g. @(+ 1 a)@.
instance Show SymInteger where
  show = pformatTerm . underlyingIntegerTerm

-- | The value itself is contributed as a single 'SomeSym', prepended to the
-- accumulated list.
instance AllSyms SymInteger where
  allSymsS v rest = SomeSym v : rest

-- | Serialized exactly as the underlying term; the newtype wrapper adds
-- nothing to the encoding.
instance Serial SymInteger where
  serialize (SymInteger t) = serialize t
  deserialize = fmap SymInteger deserialize

-- | @cereal@ instance delegating to the 'Serial' instance.
instance Cereal.Serialize SymInteger where
  put v = serialize v
  get = deserialize

-- | @binary@ instance delegating to the 'Serial' instance.
instance Binary.Binary SymInteger where
  put v = serialize v
  get = deserialize