{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Unified.UnifiedFun
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Unified.UnifiedFun
  ( UnifiedFunConstraint,
    UnifiedFun (..),
    unifiedFunInstanceName,
    genUnifiedFunInstance,
    GetFun2,
    GetFun3,
    GetFun4,
    GetFun5,
    GetFun6,
    GetFun7,
    GetFun8,
  )
where

#if MIN_VERSION_base(4,20,0)
#else
import Data.Foldable (Foldable (foldl'))
#endif

import Control.DeepSeq (NFData)
import Data.Binary (Binary)
import Data.Bytes.Serial (Serial)
import Data.Hashable (Hashable)
import qualified Data.Kind
import Data.Serialize (Serialize)
import Data.Typeable (Typeable)
import GHC.TypeLits (KnownNat, Nat, type (<=))
import Grisette.Internal.Core.Data.Class.EvalSym (EvalSym)
import Grisette.Internal.Core.Data.Class.ExtractSym (ExtractSym)
import Grisette.Internal.Core.Data.Class.Function (Apply (FunType), Function)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.PPrint (PPrint)
import Grisette.Internal.Core.Data.Class.SubstSym (SubstSym)
import Grisette.Internal.Core.Data.Class.ToCon (ToCon)
import Grisette.Internal.Core.Data.Class.ToSym (ToSym)
import Grisette.Internal.SymPrim.AlgReal (AlgReal)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP (FP, ValidFP)
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymBool (SymBool)
import Grisette.Internal.SymPrim.SymFP (SymFP)
import Grisette.Internal.SymPrim.SymInteger (SymInteger)
import Grisette.Internal.SymPrim.SymTabularFun (type (=~>))
import Grisette.Internal.SymPrim.TabularFun (type (=->))
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag (C, S))
import Grisette.Internal.Unified.Theories
  ( TheoryToUnify (UAlgReal, UBool, UFP, UFun, UIntN, UInteger, UWordN),
  )
import Grisette.Internal.Unified.UnifiedAlgReal (GetAlgReal)
import Grisette.Internal.Unified.UnifiedBV (UnifiedBVImpl (GetIntN, GetWordN))
import Grisette.Internal.Unified.UnifiedBool (UnifiedBool (GetBool))
import Grisette.Internal.Unified.UnifiedFP (UnifiedFPImpl (GetFP))
import Grisette.Internal.Unified.UnifiedInteger (GetInteger)
import Language.Haskell.TH
  ( DecsQ,
    Pred,
    Q,
    TyLit (NumTyLit),
    Type (AppT, ConT, ForallT, LitT, VarT),
    appT,
    classD,
    conT,
    instanceD,
    mkName,
    newName,
    promotedT,
    varT,
  )
import qualified Language.Haskell.TH
import Language.Haskell.TH.Datatype.TyVarBndr
  ( kindedTV,
    mapTVFlag,
    specifiedSpec,
    tvName,
  )
import Language.Haskell.TH.Syntax (Lift)

-- Version-portable alias for the type-variable binder used in the class heads
-- generated by this module. template-haskell >= 2.21 exports 'TyVarBndrVis'
-- itself; 2.17 up to 2.21 use the unit-flagged @TyVarBndr ()@; older versions
-- have an unparameterised 'TyVarBndr'.
#if MIN_VERSION_template_haskell(2,21,0)
type TyVarBndrVis = Language.Haskell.TH.TyVarBndrVis
#elif MIN_VERSION_template_haskell(2,17,0)
type TyVarBndrVis = Language.Haskell.TH.TyVarBndr ()
#else
type TyVarBndrVis = Language.Haskell.TH.TyVarBndr
#endif

-- | Provide unified function types.
--
-- The associated type 'GetFun' carries an injectivity annotation
-- (@fun -> mode@), so the evaluation mode can be recovered from the function
-- type constructor it resolves to.
class UnifiedFun (mode :: EvalModeTag) where
  -- | Get a unified function type. Resolves to t'Grisette.SymPrim.=->' in 'C'
  -- mode, and t'Grisette.SymPrim.=~>' in 'S' mode.
  type
    GetFun mode =
      (fun :: Data.Kind.Type -> Data.Kind.Type -> Data.Kind.Type) | fun -> mode

-- | Concrete mode: unified functions are t'Grisette.SymPrim.=->'.
instance UnifiedFun 'C where
  type GetFun 'C = (=->)

-- | Symbolic mode: unified functions are t'Grisette.SymPrim.=~>'.
instance UnifiedFun 'S where
  type GetFun 'S = (=~>)

-- | The unified function type with 2 type arguments: a function from @a@ to
-- @b@ in the evaluation mode @mode@.
type GetFun2 mode a b = GetFun mode a b

-- | The unified function type with 3 type arguments, curried as
-- @a -> (b -> c)@; built by prepending @a@ to 'GetFun2'.
type GetFun3 mode a b c = GetFun mode a (GetFun2 mode b c)

-- | The unified function type with 4 type arguments, curried; built by
-- prepending @a@ to 'GetFun3'.
type GetFun4 mode a b c d = GetFun mode a (GetFun3 mode b c d)

-- | The unified function type with 5 type arguments, curried; built by
-- prepending @a@ to 'GetFun4'.
type GetFun5 mode a b c d e = GetFun mode a (GetFun4 mode b c d e)

-- | The unified function type with 6 type arguments, curried; built by
-- prepending @a@ to 'GetFun5'.
type GetFun6 mode a b c d e f = GetFun mode a (GetFun5 mode b c d e f)

-- | The unified function type with 7 type arguments, curried; built by
-- prepending @a@ to 'GetFun6'.
type GetFun7 mode a b c d e f g = GetFun mode a (GetFun6 mode b c d e f g)

-- | The unified function type with 8 type arguments, curried; built by
-- prepending @a@ to 'GetFun7'.
type GetFun8 mode a b c d e f g h =
  GetFun mode a (GetFun7 mode b c d e f g h)

-- | The constraint for a unified function.
--
-- @UnifiedFunConstraint mode a b ca cb sa sb@ constrains the unified function
-- type @GetFun mode a b@ (unified argument @a@, unified result @b@) so that:
--
-- * it has the standard, serialization, and Grisette core class instances;
-- * it converts to and from the concrete function @ca =-> cb@ and the
--   symbolic function @sa =~> sb@ via 'ToCon' and 'ToSym';
-- * it is a 'Function' from @a@ to @b@ whose 'FunType' under 'Apply' is
--   @a -> b@.
type UnifiedFunConstraint mode a b ca cb sa sb =
  ( -- Standard classes.
    Eq (GetFun mode a b),
    Show (GetFun mode a b),
    Hashable (GetFun mode a b),
    NFData (GetFun mode a b),
    Typeable (GetFun mode a b),
    Lift (GetFun mode a b),
    -- Serialization.
    Binary (GetFun mode a b),
    Serial (GetFun mode a b),
    Serialize (GetFun mode a b),
    -- Grisette core classes.
    Mergeable (GetFun mode a b),
    EvalSym (GetFun mode a b),
    ExtractSym (GetFun mode a b),
    SubstSym (GetFun mode a b),
    PPrint (GetFun mode a b),
    -- Conversions between the unified, concrete, and symbolic function types.
    ToCon (GetFun mode a b) (ca =-> cb),
    ToCon (sa =~> sb) (GetFun mode a b),
    ToSym (GetFun mode a b) (sa =~> sb),
    ToSym (ca =-> cb) (GetFun mode a b),
    -- Function application.
    Function (GetFun mode a b) a b,
    Apply (GetFun mode a b),
    FunType (GetFun mode a b) ~ (a -> b)
  )

-- | Generate the inner class @nm mode bndrs...@ that bundles the
-- 'UnifiedFunConstraint's for a list of (unified, concrete, symbolic) types.
--
-- For @tys = [t1, ..., tn]@ the class gets one 'UnifiedFunConstraint'
-- superclass per argument position @i < n@, pairing @ti@ with the curried
-- function type built from @t(i+1), ..., tn@. For example, @[a, b, c]@ yields
-- constraints for @a@ with @b -> c@ and for @b@ with @c@. After the class,
-- instances @preds => nm 'C bndrs...@ and @preds => nm 'S bndrs...@ are
-- emitted, in that order.
--
-- Fails in 'Q' when @tys@ is empty. A single type yields a class with no
-- superclasses.
genInnerUnifiedFunInstance ::
  String ->
  TyVarBndrVis ->
  [Pred] ->
  [TyVarBndrVis] ->
  [(Type, Type, Type)] ->
  DecsQ
genInnerUnifiedFunInstance nm mode preds bndrs tys = do
  classDec <- classD (goPred tys) className (mode : bndrs) [] []
  modeInsts <- traverse mkModeInstance ['C, 'S]
  pure (classDec : modeInsts)
  where
    className :: Language.Haskell.TH.Name
    className = mkName nm

    modeVar :: Type
    modeVar = VarT (tvName mode)

    -- Empty-bodied instance @preds => nm tag b1 ... bn@, where each binder is
    -- turned back into a type variable.
    mkModeInstance ::
      Language.Haskell.TH.Name -> Q Language.Haskell.TH.Dec
    mkModeInstance tag =
      instanceD
        (pure preds)
        ( foldl' appT (conT className) $
            promotedT tag : (varT . tvName <$> bndrs)
        )
        []

    -- One constraint per argument position; the last entry is only ever a
    -- result type, so it contributes no constraint of its own.
    goPred :: [(Type, Type, Type)] -> Q [Pred]
    goPred [] = fail "Empty list of function types, at least 2."
    goPred [_] = pure []
    goPred (arg : rest@(r : rs)) =
      (:) <$> mkPred arg (curried r rs) <*> goPred rest

    -- Right-nest a non-empty list of triples into one curried function type
    -- per flavour: unified (GetFun mode), concrete (=->), symbolic (=~>).
    -- Taking the head separately keeps this total.
    curried ::
      (Type, Type, Type) -> [(Type, Type, Type)] -> (Type, Type, Type)
    curried t [] = t
    curried (u, c, s) (t : ts) =
      let (u', c', s') = curried t ts
       in ( AppT (AppT (AppT (ConT ''GetFun) modeVar) u) u',
            AppT (AppT (ConT ''(=->)) c) c',
            AppT (AppT (ConT ''(=~>)) s) s'
          )

    -- The 'UnifiedFunConstraint' for argument triple @(ua, ca, sa)@ and
    -- result triple @(ub, cb, sb)@.
    mkPred :: (Type, Type, Type) -> (Type, Type, Type) -> Q Pred
    mkPred (ua, ca, sa) (ub, cb, sb) =
      [t|
        UnifiedFunConstraint
          $(pure modeVar)
          $(pure ua)
          $(pure ub)
          $(pure ca)
          $(pure cb)
          $(pure sa)
          $(pure sb)
        |]

-- | Generate the outer class @nm mode@, which packages the inner class
-- @innerName@ for every choice of the extra type variables @bndrs@.
--
-- The class has a single quantified superclass
-- @forall bndrs. preds => innerName mode bndrs...@, with the binders marked as
-- specified. A single @nm mode@ constraint therefore provides the inner class
-- at any instantiation of @bndrs@ that satisfies @preds@. Empty instances
-- @nm 'C@ and @nm 'S@ follow the class declaration, in that order.
genOuterUnifiedFunInstance ::
  String -> String -> TyVarBndrVis -> [Pred] -> [TyVarBndrVis] -> DecsQ
genOuterUnifiedFunInstance nm innerName mode preds bndrs = do
  classDec <- classD (pure [superPred]) className [mode] [] []
  modeInsts <- traverse mkModeInstance ['C, 'S]
  pure (classDec : modeInsts)
  where
    className :: Language.Haskell.TH.Name
    className = mkName nm

    -- forall bndrs. preds => innerName mode b1 ... bn
    superPred :: Pred
    superPred =
      ForallT (mapTVFlag (const specifiedSpec) <$> bndrs) preds $
        foldl' AppT (ConT (mkName innerName)) (VarT . tvName <$> (mode : bndrs))

    -- Empty-bodied, context-free instance of the outer class at a mode tag.
    mkModeInstance ::
      Language.Haskell.TH.Name -> Q Language.Haskell.TH.Dec
    mkModeInstance tag =
      instanceD (pure []) (appT (conT className) (promotedT tag)) []

-- | Generate unified function instance names.
--
-- The result is @prefix@, then @"Fun"@, then the 'show' of each theory in
-- list order, all concatenated.
unifiedFunInstanceName :: String -> [TheoryToUnify] -> String
unifiedFunInstanceName prefix theories =
  prefix ++ "Fun" ++ concatMap show theories

-- | Generate unified function instances.
--
-- Each theory is mapped to its unified, concrete, and symbolic types. Theories
-- parameterised by sizes (bit-vectors and floating point) contribute fresh
-- type variables and their constraints. From these pieces the function emits:
--
-- * an inner class named by 'unifiedFunInstanceName', bundling a
--   'UnifiedFunConstraint' for each curried function type over the theories;
-- * when at least one theory introduced type variables, an outer class with
--   the same name prefixed by @All@, which quantifies over those variables.
--
-- Fails in 'Q' when a theory is 'UFun' (nested functions are not supported)
-- or when the theory list is empty.
genUnifiedFunInstance :: String -> [TheoryToUnify] -> DecsQ
genUnifiedFunInstance prefix theories = do
  modeName <- newName "mode"
  allArgs <- traverse (genArgs (VarT modeName)) theories
  let baseName = unifiedFunInstanceName prefix theories
      modeBndr :: TyVarBndrVis
      modeBndr = kindedTV modeName (ConT ''EvalModeTag)
      allBndrs = concat [bndrs | (bndrs, _, _, _, _) <- allArgs]
      allPreds = concat [preds | (_, preds, _, _, _) <- allArgs]
      allTys = [(u, c, s) | (_, _, u, c, s) <- allArgs]
  inner <-
    genInnerUnifiedFunInstance baseName modeBndr allPreds allBndrs allTys
  -- The outer class only adds something when there are variables to
  -- quantify over.
  outer <-
    if null allBndrs
      then pure []
      else
        genOuterUnifiedFunInstance
          ("All" ++ baseName)
          baseName
          modeBndr
          allPreds
          allBndrs
  pure (inner ++ outer)
  where
    -- For one theory: (fresh binders, constraints on them, unified type,
    -- concrete type, symbolic type).
    genArgs ::
      Type -> TheoryToUnify -> Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    genArgs mode UBool =
      nullaryArgs (AppT (ConT ''GetBool) mode) (ConT ''Bool) (ConT ''SymBool)
    genArgs mode UIntN =
      bvArgs mode (ConT ''GetIntN) (ConT ''IntN) (ConT ''SymIntN)
    genArgs mode UWordN =
      bvArgs mode (ConT ''GetWordN) (ConT ''WordN) (ConT ''SymWordN)
    genArgs mode UInteger =
      nullaryArgs
        (AppT (ConT ''GetInteger) mode)
        (ConT ''Integer)
        (ConT ''SymInteger)
    genArgs mode UAlgReal =
      nullaryArgs
        (AppT (ConT ''GetAlgReal) mode)
        (ConT ''AlgReal)
        (ConT ''SymAlgReal)
    genArgs mode UFP = do
      eb <- newName "eb"
      sb <- newName "sb"
      let ebType = VarT eb
          sbType = VarT sb
      pure
        ( [kindedTV eb (ConT ''Nat), kindedTV sb (ConT ''Nat)],
          [AppT (AppT (ConT ''ValidFP) ebType) sbType],
          AppT (AppT (AppT (ConT ''GetFP) mode) ebType) sbType,
          AppT (AppT (ConT ''FP) ebType) sbType,
          AppT (AppT (ConT ''SymFP) ebType) sbType
        )
    genArgs _ UFun {} = fail "UFun cannot be nested."

    -- A theory without type parameters: no binders and no constraints.
    nullaryArgs ::
      Type -> Type -> Type -> Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    nullaryArgs unified con sym = pure ([], [], unified, con, sym)

    -- A bit-vector theory indexed by a fresh width @n :: Nat@, constrained
    -- by @KnownNat n@ and @1 <= n@. The unified constructor receives the
    -- mode and @n@; the concrete and symbolic constructors receive @n@ only.
    bvArgs ::
      Type ->
      Type ->
      Type ->
      Type ->
      Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    bvArgs mode unifiedCon conCon symCon = do
      n <- newName "n"
      let nType = VarT n
      pure
        ( [kindedTV n (ConT ''Nat)],
          [ AppT (ConT ''KnownNat) nType,
            AppT (AppT (ConT ''(<=)) (LitT $ NumTyLit 1)) nType
          ],
          AppT (AppT unifiedCon mode) nType,
          AppT conCon nType,
          AppT symCon nType
        )