{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Internal.Unified.UnifiedFun
( UnifiedFunConstraint,
UnifiedFun (..),
unifiedFunInstanceName,
genUnifiedFunInstance,
GetFun2,
GetFun3,
GetFun4,
GetFun5,
GetFun6,
GetFun7,
GetFun8,
)
where
#if MIN_VERSION_base(4,20,0)
#else
import Data.Foldable (Foldable (foldl'))
#endif
import Control.DeepSeq (NFData)
import Data.Binary (Binary)
import Data.Bytes.Serial (Serial)
import Data.Hashable (Hashable)
import qualified Data.Kind
import Data.Serialize (Serialize)
import Data.Typeable (Typeable)
import GHC.TypeLits (KnownNat, Nat, type (<=))
import Grisette.Internal.Core.Data.Class.EvalSym (EvalSym)
import Grisette.Internal.Core.Data.Class.ExtractSym (ExtractSym)
import Grisette.Internal.Core.Data.Class.Function (Apply (FunType), Function)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.PPrint (PPrint)
import Grisette.Internal.Core.Data.Class.SubstSym (SubstSym)
import Grisette.Internal.Core.Data.Class.ToCon (ToCon)
import Grisette.Internal.Core.Data.Class.ToSym (ToSym)
import Grisette.Internal.SymPrim.AlgReal (AlgReal)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP (FP, ValidFP)
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymBool (SymBool)
import Grisette.Internal.SymPrim.SymFP (SymFP)
import Grisette.Internal.SymPrim.SymInteger (SymInteger)
import Grisette.Internal.SymPrim.SymTabularFun (type (=~>))
import Grisette.Internal.SymPrim.TabularFun (type (=->))
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag (C, S))
import Grisette.Internal.Unified.Theories
( TheoryToUnify (UAlgReal, UBool, UFP, UFun, UIntN, UInteger, UWordN),
)
import Grisette.Internal.Unified.UnifiedAlgReal (GetAlgReal)
import Grisette.Internal.Unified.UnifiedBV (UnifiedBVImpl (GetIntN, GetWordN))
import Grisette.Internal.Unified.UnifiedBool (UnifiedBool (GetBool))
import Grisette.Internal.Unified.UnifiedFP (UnifiedFPImpl (GetFP))
import Grisette.Internal.Unified.UnifiedInteger (GetInteger)
import Language.Haskell.TH
( DecsQ,
Pred,
Q,
TyLit (NumTyLit),
Type (AppT, ConT, ForallT, LitT, VarT),
appT,
classD,
conT,
instanceD,
mkName,
newName,
promotedT,
varT,
)
import qualified Language.Haskell.TH
import Language.Haskell.TH.Datatype.TyVarBndr
( kindedTV,
mapTVFlag,
specifiedSpec,
tvName,
)
import Language.Haskell.TH.Syntax (Lift)
-- | The type-variable binder type accepted for class heads by the
-- declaration builders in this module (see the uses with classD).
--
-- template-haskell 2.21 and later provide a dedicated binder type for
-- declaration heads; versions 2.17 through 2.20 use binders flagged with
-- unit; older versions have a single unflagged binder type.
#if MIN_VERSION_template_haskell(2,21,0)
type TyVarBndrVis = Language.Haskell.TH.TyVarBndrVis
#elif MIN_VERSION_template_haskell(2,17,0)
type TyVarBndrVis = Language.Haskell.TH.TyVarBndr ()
#else
type TyVarBndrVis = Language.Haskell.TH.TyVarBndr
#endif
-- | Chooses the function type for an evaluation mode.
--
-- In the concrete mode @'C@ functions are concrete tabular functions
-- @(=->)@; in the symbolic mode @'S@ they are symbolic tabular functions
-- @(=~>)@. The injectivity annotation (@fun -> mode@) lets GHC recover the
-- mode from the function type constructor.
class UnifiedFun (mode :: EvalModeTag) where
  -- | The binary function type constructor used in @mode@.
  type
    GetFun mode =
      (fun :: Data.Kind.Type -> Data.Kind.Type -> Data.Kind.Type) | fun -> mode

-- | Concrete mode: concrete tabular functions.
instance UnifiedFun 'C where
  type GetFun 'C = (=->)

-- | Symbolic mode: symbolic tabular functions.
instance UnifiedFun 'S where
  type GetFun 'S = (=~>)
-- | A unified function @a -> b@ in the given mode.
type GetFun2 mode a b = GetFun mode a b

-- | A curried unified function @a -> b -> c@ in the given mode.
type GetFun3 mode a b c = GetFun mode a (GetFun2 mode b c)

-- | A curried unified function @a -> b -> c -> d@ in the given mode.
type GetFun4 mode a b c d = GetFun mode a (GetFun3 mode b c d)

-- | A curried unified function of four arguments in the given mode.
type GetFun5 mode a b c d e = GetFun mode a (GetFun4 mode b c d e)

-- | A curried unified function of five arguments in the given mode.
type GetFun6 mode a b c d e f = GetFun mode a (GetFun5 mode b c d e f)

-- | A curried unified function of six arguments in the given mode.
type GetFun7 mode a b c d e f g = GetFun mode a (GetFun6 mode b c d e f g)

-- | A curried unified function of seven arguments in the given mode.
type GetFun8 mode a b c d e f g h =
  GetFun mode a (GetFun7 mode b c d e f g h)
-- | The capabilities available on a unified function type
-- @'GetFun' mode a b@.
--
-- * @a@ and @b@ are the unified argument and result types for @mode@;
-- * @ca@ and @cb@ are the types of the concrete tabular function
--   @ca =-> cb@ the unified function converts to/from;
-- * @sa@ and @sb@ are the types of the symbolic tabular function
--   @sa =~> sb@ the unified function converts to/from.
--
-- The unified function can also be applied as a plain @a -> b@.
type UnifiedFunConstraint mode a b ca cb sa sb =
  ( -- Basic and serialization classes.
    Show (GetFun mode a b),
    Binary (GetFun mode a b),
    Serial (GetFun mode a b),
    Serialize (GetFun mode a b),
    NFData (GetFun mode a b),
    Eq (GetFun mode a b),
    -- Symbolic manipulation classes.
    EvalSym (GetFun mode a b),
    ExtractSym (GetFun mode a b),
    Mergeable (GetFun mode a b),
    PPrint (GetFun mode a b),
    SubstSym (GetFun mode a b),
    -- Hashing, lifting and runtime type information.
    Hashable (GetFun mode a b),
    Lift (GetFun mode a b),
    Typeable (GetFun mode a b),
    -- Conversions to/from concrete and symbolic tabular functions.
    ToCon (GetFun mode a b) (ca =-> cb),
    ToCon (sa =~> sb) (GetFun mode a b),
    ToSym (GetFun mode a b) (sa =~> sb),
    ToSym (ca =-> cb) (GetFun mode a b),
    -- Function application.
    Function (GetFun mode a b) a b,
    Apply (GetFun mode a b),
    FunType (GetFun mode a b) ~ (a -> b)
  )
-- | Generate the inner class of a unified function family, together with its
-- instances for both evaluation modes.
--
-- Given the @(unified, concrete, symbolic)@ type triples
-- @[(u1, c1, s1), ..., (un, cn, sn)]@ of the curried function's arguments
-- and result, this produces
--
-- > class ( UnifiedFunConstraint mode u1 (GetFun mode u2 (... un)) c1 (c2 =-> ...) s1 (s2 =~> ...),
-- >         UnifiedFunConstraint mode u2 (GetFun mode u3 (... un)) c2 (...) s2 (...),
-- >         ...
-- >       ) => nm mode bndrs
-- > instance preds => nm 'C bndrs
-- > instance preds => nm 'S bndrs
--
-- There is one 'UnifiedFunConstraint' for every suffix of the triple list
-- with at least two elements. This covers every curried sub-function: both
-- @u1 -> u2 -> u3@ and @u2 -> u3@, for instance.
--
-- A singleton list yields an empty superclass context. An empty list fails
-- in 'Q'.
genInnerUnifiedFunInstance ::
  -- | Name of the class to generate.
  String ->
  -- | Binder of the evaluation-mode type variable.
  TyVarBndrVis ->
  -- | Context of the generated instances.
  [Pred] ->
  -- | Additional type-variable binders of the class.
  [TyVarBndrVis] ->
  -- | @(unified, concrete, symbolic)@ types of each argument and the result.
  [(Type, Type, Type)] ->
  DecsQ
genInnerUnifiedFunInstance nm mode preds bndrs tys = do
  cls <- classD (goPred tys) className (mode : bndrs) [] []
  concreteInst <- modeInstance 'C
  symbolicInst <- modeInstance 'S
  return [cls, concreteInst, symbolicInst]
  where
    className = mkName nm
    modeType = VarT $ tvName mode
    -- The class arguments after the mode, as type variables.
    additionalTypes = varT . tvName <$> bndrs
    -- @instance preds => nm m bndrs@ for a promoted mode constructor @m@.
    modeInstance m =
      instanceD
        (return preds)
        (foldl' appT (conT className) (promotedT m : additionalTypes))
        []
    -- One constraint per suffix of length at least two.
    goPred :: [(Type, Type, Type)] -> Q [Pred]
    goPred [] = fail "Empty list of function types, at least 2."
    goPred [_] = return []
    goPred (x : xs) = (:) <$> funConstraint x xs <*> goPred xs
    -- Fold a non-empty list of triples into the triple of the curried
    -- function types (unified, concrete, symbolic) it denotes.
    curryTys :: [(Type, Type, Type)] -> Q (Type, Type, Type)
    curryTys [] = fail "Should not happen"
    curryTys [t] = return t
    curryTys ((u, c, s) : rest) = do
      (u', c', s') <- curryTys rest
      return
        ( AppT (AppT (AppT (ConT ''GetFun) modeType) u) u',
          AppT (AppT (ConT ''(=->)) c) c',
          AppT (AppT (ConT ''(=~>)) s) s'
        )
    -- The constraint for the function taking the first triple to the
    -- curried function described by the remaining triples.
    funConstraint :: (Type, Type, Type) -> [(Type, Type, Type)] -> Q Pred
    funConstraint (ua, ca, sa) rest = do
      (ub, cb, sb) <- curryTys rest
      [t|
        UnifiedFunConstraint
          $(return modeType)
          $(return ua)
          $(return ub)
          $(return ca)
          $(return cb)
          $(return sa)
          $(return sb)
        |]
-- | Generate an outer class that quantifies over the extra type variables of
-- an inner class, together with its instances for both evaluation modes.
--
-- Given an inner class @innerName@ with extra binders @bndrs@ and instance
-- context @preds@, this produces
--
-- > class (forall bndrs. preds => innerName mode bndrs) => nm mode
-- > instance nm 'C
-- > instance nm 'S
--
-- As a result, a single @nm mode@ constraint gives access to the inner class
-- at every instantiation of @bndrs@ that satisfies @preds@.
genOuterUnifiedFunInstance ::
  -- | Name of the class to generate.
  String ->
  -- | Name of the inner class being quantified.
  String ->
  -- | Binder of the evaluation-mode type variable.
  TyVarBndrVis ->
  -- | Constraints on the quantified type variables.
  [Pred] ->
  -- | Type variables to quantify over.
  [TyVarBndrVis] ->
  DecsQ
genOuterUnifiedFunInstance nm innerName mode preds bndrs = do
  cls <- classD (return [quantified]) className [mode] [] []
  concreteInst <- modeInstance 'C
  symbolicInst <- modeInstance 'S
  return [cls, concreteInst, symbolicInst]
  where
    className = mkName nm
    -- ForallT takes specificity-flagged binders, so re-flag the class-head
    -- binders as specified.
    specifiedBndrs = mapTVFlag (const specifiedSpec) <$> bndrs
    -- @innerName mode bndrs@
    innerHead =
      foldl' AppT (ConT $ mkName innerName) $
        VarT . tvName <$> (mode : bndrs)
    -- @forall bndrs. preds => innerName mode bndrs@
    quantified = ForallT specifiedBndrs preds innerHead
    -- @instance nm m@ for a promoted mode constructor @m@. The superclass
    -- is discharged by the instances of the inner class.
    modeInstance m =
      instanceD (return []) (appT (conT className) (promotedT m)) []
-- | The name of the inner class generated by 'genUnifiedFunInstance'.
--
-- The name is the prefix, followed by @"Fun"@, followed by the concatenated
-- 'show' of each theory.
unifiedFunInstanceName :: String -> [TheoryToUnify] -> String
unifiedFunInstanceName prefix theories =
  prefix ++ "Fun" ++ concatMap show theories
-- | Generate the classes describing a unified function over the given
-- theories.
--
-- The theories list the argument types and the result type of a curried
-- function. For example, @[UBool, UIntN, UBool]@ describes a function from a
-- unified boolean and a unified @n@-bit signed bit-vector to a unified
-- boolean, with the concrete types chosen by the evaluation mode.
--
-- The generated declarations are:
--
-- * the inner class named @'unifiedFunInstanceName' prefix theories@. Its
--   superclass context is the 'UnifiedFunConstraint' of every curried
--   sub-function, and it has instances for @'C@ and @'S@;
-- * when some theory introduces type variables (the width @n@ of bit-vectors,
--   or the @eb@/@sb@ parameters of floating points), an outer class named
--   @"All" ++ unifiedFunInstanceName prefix theories@. The outer class
--   quantifies over those variables and also has instances for @'C@ and @'S@.
--
-- Generation fails in 'Q' if the theory list is empty or contains 'UFun',
-- since nesting is not supported.
genUnifiedFunInstance :: String -> [TheoryToUnify] -> DecsQ
genUnifiedFunInstance prefix theories = do
  modeName <- newName "mode"
  allArgs <- traverse (genArgs (VarT modeName)) theories
  let baseName = unifiedFunInstanceName prefix theories
      modeBndr :: TyVarBndrVis
      modeBndr = kindedTV modeName (ConT ''EvalModeTag)
      bndrs = concatMap (\(b, _, _, _, _) -> b) allArgs
      preds = concatMap (\(_, p, _, _, _) -> p) allArgs
      tys = (\(_, _, u, c, s) -> (u, c, s)) <$> allArgs
  inner <- genInnerUnifiedFunInstance baseName modeBndr preds bndrs tys
  -- The quantifying outer class is only needed when there are type
  -- variables to quantify over.
  outer <-
    if null bndrs
      then return []
      else
        genOuterUnifiedFunInstance
          ("All" ++ baseName)
          baseName
          modeBndr
          preds
          bndrs
  return $ inner ++ outer
  where
    -- For one theory, returns: the fresh type-variable binders it
    -- introduces, the constraints on them, and its unified, concrete and
    -- symbolic types.
    genArgs ::
      Type -> TheoryToUnify -> Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    genArgs mode UBool = monoArgs ''GetBool ''Bool ''SymBool mode
    genArgs mode UInteger = monoArgs ''GetInteger ''Integer ''SymInteger mode
    genArgs mode UAlgReal = monoArgs ''GetAlgReal ''AlgReal ''SymAlgReal mode
    genArgs mode UIntN = bvArgs ''GetIntN ''IntN ''SymIntN mode
    genArgs mode UWordN = bvArgs ''GetWordN ''WordN ''SymWordN mode
    genArgs mode UFP = do
      eb <- newName "eb"
      sb <- newName "sb"
      let ebType = VarT eb
          sbType = VarT sb
          withSizes con = AppT (AppT con ebType) sbType
      return
        ( [kindedTV eb (ConT ''Nat), kindedTV sb (ConT ''Nat)],
          [withSizes (ConT ''ValidFP)],
          withSizes (AppT (ConT ''GetFP) mode),
          withSizes (ConT ''FP),
          withSizes (ConT ''SymFP)
        )
    genArgs _ UFun {} = fail "UFun cannot be nested."

    -- A theory without type parameters.
    monoArgs ::
      Language.Haskell.TH.Name ->
      Language.Haskell.TH.Name ->
      Language.Haskell.TH.Name ->
      Type ->
      Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    monoArgs unified concrete symbolic mode =
      return ([], [], AppT (ConT unified) mode, ConT concrete, ConT symbolic)

    -- A bit-vector theory, parameterized by a fresh width @n@ with
    -- @KnownNat n@ and @1 <= n@.
    bvArgs ::
      Language.Haskell.TH.Name ->
      Language.Haskell.TH.Name ->
      Language.Haskell.TH.Name ->
      Type ->
      Q ([TyVarBndrVis], [Pred], Type, Type, Type)
    bvArgs unified concrete symbolic mode = do
      n <- newName "n"
      let nType = VarT n
      return
        ( [kindedTV n (ConT ''Nat)],
          [ AppT (ConT ''KnownNat) nType,
            AppT (AppT (ConT ''(<=)) (LitT $ NumTyLit 1)) nType
          ],
          AppT (AppT (ConT unified) mode) nType,
          AppT (ConT concrete) nType,
          AppT (ConT symbolic) nType
        )