{-# HLINT ignore "Unused LANGUAGE pragma" #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.SymGeneralFun
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.SymGeneralFun
  ( type (-~>) (SymGeneralFun),
    (-->),
  )
where

import Control.DeepSeq (NFData (rnf))
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import GHC.Generics (Generic)
import Grisette.Internal.Core.Data.Class.Function
  ( Apply (FunType, apply),
    Function ((#)),
  )
import Grisette.Internal.Core.Data.Class.Solvable
  ( Solvable (con, conView, ssym, sym),
  )
import Grisette.Internal.Internal.Decl.SymPrim.AllSyms
  ( AllSyms (allSymsS),
    SomeSym (SomeSym),
  )
import Grisette.Internal.SymPrim.GeneralFun (buildGeneralFun, type (-->))
import Grisette.Internal.SymPrim.Prim.Term
  ( ConRep (ConType),
    LinkedRep (underlyingTerm, wrapTerm),
    PEvalApplyTerm (pevalApplyTerm),
    SupportedNonFuncPrim,
    SupportedPrim,
    SymRep (SymType),
    Term,
    TypedConstantSymbol,
    conTerm,
    pformatTerm,
    symTerm,
    typedAnySymbol,
    pattern ConTerm,
  )
import Language.Haskell.TH.Syntax (Lift (liftTyped))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- |
-- Symbolic general function type.
--
-- A value of type @sa -~> sb@ wraps a term whose type is the concrete general
-- function type @ca --> cb@, where @ca@ and @cb@ are the concrete types linked
-- to the symbolic types @sa@ and @sb@. The argument type @ca@ must be a
-- non-function primitive. Application to symbolic arguments is shown in the
-- examples below.
--
-- >>> f' = "f" :: SymInteger -~> SymInteger
-- >>> f = (f' #)
-- >>> f 1
-- (apply f 1)
--
-- >>> f' = con ("a" --> "a" + 1) :: SymInteger -~> SymInteger
-- >>> f'
-- \(arg@0 :: Integer) -> (+ 1 arg@0)
-- >>> f = (f' #)
-- >>> f 1
-- 2
-- >>> f 2
-- 3
-- >>> f 3
-- 4
-- >>> f "b"
-- (+ 1 b)
data sa -~> sb where
  SymGeneralFun ::
    ( LinkedRep ca sa,
      LinkedRep cb sb,
      -- The argument type may not itself be a function type.
      SupportedNonFuncPrim ca,
      SupportedPrim (ca --> cb)
    ) =>
    Term (ca --> cb) ->
    sa -~> sb

infixr 0 -~>

-- | Construction of general symbolic functions.
--
-- @arg --> body@ builds a concrete general function of type @ca --> cb@ that
-- binds the constant symbol @arg@ in the symbolic expression @body@. The
-- underlying term of @body@ is extracted with 'underlyingTerm' and passed,
-- together with @arg@, to 'buildGeneralFun'.
--
-- >>> f = "a" --> "a" + 1 :: Integer --> Integer
-- >>> f
-- \(arg@0 :: Integer) -> (+ 1 arg@0)
--
-- The resulting general function is applied to symbolic values:
--
-- >>> f # ("a" :: SymInteger)
-- (+ 1 a)
-- >>> f # (2 :: SymInteger)
-- 3
(-->) ::
  (SupportedNonFuncPrim ca, SupportedPrim cb, LinkedRep cb sb) =>
  TypedConstantSymbol ca ->
  sb ->
  ca --> cb
arg --> body = buildGeneralFun arg (underlyingTerm body)

infixr 0 -->

-- | A single-constructor marker type.
--
-- NOTE(review): 'ARG' is not exported from this module, and apart from its
-- own instances no definition visible in this module refers to it. Confirm
-- that it is still needed.
data ARG = ARG
  deriving (Eq, Ord, Lift, Show, Generic)

-- | 'ARG' has a single nullary constructor, so matching it is all that is
-- needed to reach normal form.
instance NFData ARG where
  rnf ARG = ()

-- | Every 'ARG' hashes the same way: the salt is combined with the constant
-- @0 :: Int@.
instance Hashable ARG where
  hashWithSalt s ARG = s `hashWithSalt` (0 :: Int)

-- | Lifting a symbolic general function lifts its underlying term and
-- rebuilds the value with the 'SymGeneralFun' constructor inside a typed
-- Template Haskell quote.
instance Lift (sa -~> sb) where
  liftTyped (SymGeneralFun t) = [||SymGeneralFun t||]

-- | Forcing a symbolic general function to normal form forces its underlying
-- term.
instance NFData (sa -~> sb) where
  rnf (SymGeneralFun t) = rnf t

-- | The concrete counterpart of the symbolic general function @sa -~> sb@ is
-- the concrete general function between the concrete counterparts of @sa@
-- and @sb@.
instance (ConRep sa, ConRep sb) => ConRep (sa -~> sb) where
  type ConType (sa -~> sb) = ConType sa --> ConType sb

-- | The symbolic counterpart of the concrete general function @ca --> cb@ is
-- the symbolic general function between the symbolic counterparts of @ca@
-- and @cb@.
instance
  (SupportedPrim (ca --> cb), SymRep ca, SymRep cb) =>
  SymRep (ca --> cb)
  where
  type SymType (ca --> cb) = SymType ca -~> SymType cb

-- | Links the concrete general function type @ca --> cb@ with the symbolic
-- type @sa -~> sb@. Both conversions are trivial: 'underlyingTerm' unwraps
-- the 'SymGeneralFun' constructor and 'wrapTerm' applies it.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim cb,
    SupportedPrim (ca --> cb),
    SupportedNonFuncPrim ca
  ) =>
  LinkedRep (ca --> cb) (sa -~> sb)
  where
  underlyingTerm (SymGeneralFun a) = a
  wrapTerm = SymGeneralFun

-- | Applies a symbolic general function to a symbolic argument.
--
-- The application term is built from the two underlying terms with
-- 'pevalApplyTerm', and the result is wrapped back into the symbolic result
-- type. The examples on the type show that applying a concrete function to a
-- concrete argument yields a simplified value.
instance Function (sa -~> sb) sa sb where
  (SymGeneralFun f) # t = wrapTerm $ pevalApplyTerm f (underlyingTerm t)

-- | Turns a (possibly nested) symbolic general function into an ordinary
-- Haskell function. One argument is applied with the 'Function' operator, and
-- 'apply' continues on the result for any remaining arguments.
instance (Apply st) => Apply (sa -~> st) where
  type FunType (sa -~> st) = sa -> FunType st
  apply uf a = apply (uf # a)

-- | Symbolic general functions can be created in two ways:
--
-- * from a concrete general function with 'con' (a concrete term);
-- * as a symbolic constant named by a symbol with 'sym'.
--
-- 'conView' recovers the concrete function only when the underlying term is a
-- concrete term, and returns 'Nothing' otherwise.
instance
  ( SupportedNonFuncPrim ca,
    LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca --> cb)
  ) =>
  Solvable (ca --> cb) (sa -~> sb)
  where
  con = SymGeneralFun . conTerm
  sym = SymGeneralFun . symTerm . typedAnySymbol
  conView (SymGeneralFun (ConTerm t)) = Just t
  conView _ = Nothing

-- | A string literal denotes a symbolic general function constant. The string
-- is converted to an identifier, and the constant is built from it with
-- 'ssym'.
instance
  ( SupportedPrim (ca --> cb),
    SupportedNonFuncPrim ca,
    LinkedRep ca sa,
    LinkedRep cb sb
  ) =>
  IsString (sa -~> sb)
  where
  -- The inner 'fromString' builds the identifier passed to 'ssym'.
  fromString = ssym . fromString

-- | Shows a symbolic general function by pretty-printing its underlying term
-- with 'pformatTerm'.
instance Show (sa -~> sb) where
  show (SymGeneralFun t) = pformatTerm t

-- | Two symbolic general functions are equal when their underlying terms are
-- equal according to the 'Eq' instance of 'Term'.
instance Eq (sa -~> sb) where
  SymGeneralFun l == SymGeneralFun r = l == r

-- | Hashes the underlying term. Like the 'Eq' instance, this looks only at the
-- wrapped term.
instance Hashable (sa -~> sb) where
  hashWithSalt s (SymGeneralFun v) = s `hashWithSalt` v

-- | Reports the whole symbolic general function as a single 'SomeSym' entry,
-- prepended to the accumulator. The underlying term is not traversed here.
instance AllSyms (sa -~> sb) where
  -- The instance head has no context. Matching the constructor presumably
  -- supplies the constraints that 'SomeSym' needs.
  allSymsS v@SymGeneralFun {} = (SomeSym v :)

-- | Serialization goes through the underlying term. 'serialize' writes the
-- term, and 'deserialize' reads a term and wraps it with 'SymGeneralFun'.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca --> cb),
    SupportedNonFuncPrim ca
  ) =>
  Serial (sa -~> sb)
  where
  serialize = serialize . underlyingTerm
  deserialize = SymGeneralFun <$> deserialize

-- | Delegates to the 'Serial' instance.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca --> cb),
    SupportedNonFuncPrim ca
  ) =>
  Cereal.Serialize (sa -~> sb)
  where
  put = serialize
  get = deserialize

-- | Delegates to the 'Serial' instance.
instance
  ( LinkedRep ca sa,
    LinkedRep cb sb,
    SupportedPrim (ca --> cb),
    SupportedNonFuncPrim ca
  ) =>
  Binary.Binary (sa -~> sb)
  where
  put = serialize
  get = deserialize