{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Avoid lambda" #-}
module Grisette.Internal.SymPrim.FunInstanceGen
( supportedPrimFun,
supportedPrimFunUpTo,
)
where
import qualified Data.SBV as SBV
import Grisette.Internal.SymPrim.Prim.Internal.Term
( IsSymbolKind,
SupportedNonFuncPrim,
SupportedPrim
( castTypedSymbol,
conSBVTerm,
defaultValue,
funcDummyConstraint,
parseSMTModelResult,
pevalDistinctTerm,
pevalEqTerm,
pevalITETerm,
sameCon,
sbvDistinct,
sbvEq,
symSBVName,
symSBVTerm,
withPrim
),
TypedSymbol (unTypedSymbol),
decideSymbolKind,
translateTypeError,
typedAnySymbol,
withNonFuncPrim,
)
import Language.Haskell.TH
( Cxt,
Dec (InstanceD),
DecsQ,
Exp,
ExpQ,
Name,
Overlap (Overlapping),
Q,
Type,
TypeQ,
forallT,
lamE,
newName,
sigD,
stringE,
varE,
varP,
varT,
)
import Language.Haskell.TH.Datatype.TyVarBndr
( plainTVInferred,
plainTVSpecified,
)
import Type.Reflection (TypeRep, typeRep, type (:~~:) (HRefl))
-- | Build a single instance declaration, optionally tagged with an overlap
-- pragma.
--
-- The context, the method declaration groups and the instance head are run in
-- that order; the method declaration groups are concatenated to form the
-- instance body.
instanceWithOverlapDescD ::
  Maybe Overlap -> Q Cxt -> Q Type -> [DecsQ] -> DecsQ
instanceWithOverlapDescD o ctxts ty descs = do
  ctxts1 <- ctxts
  descs1 <- sequence descs
  ty1 <- ty
  return [InstanceD o ctxts1 ty1 (concat descs1)]
-- | Generate a 'SupportedPrim' instance for a curried function type with
-- @numArg@ type variables (argument types plus the result type).
--
-- Fresh type variables @a0 .. a(numArg-1)@ are created. The instance head is
-- @SupportedPrim (F a0 (F a1 (... a(numArg-1))))@, where @F@ is the type
-- constructor named by @funTypeName@. Every @ai@ gets a
-- 'SupportedNonFuncPrim' constraint.
--
-- Instances with @numArg /= 2@ are marked @OVERLAPPING@. GHC matches instance
-- heads without looking at contexts, so the binary head @F a0 a1@ also matches
-- the longer nested heads.
--
-- Arguments, in order:
--
-- * @dv@: expression used for 'defaultValue'.
-- * @ite@: expression used for 'pevalITETerm'.
-- * @parse@: expression used for 'parseSMTModelResult'.
-- * @consbv@: given the type variables, builds the expression for
--   'conSBVTerm'.
-- * @funNameInError@: text inserted into the internal-error messages produced
--   for the unsupported equality operations.
-- * @funNamePrefix@: prefix of the SBV symbol names. 'symSBVName' yields
--   @funNamePrefix ++ show numArg ++ "_" ++ show num@.
-- * @funTypeName@: the type constructor used to build the function type.
-- * @numArg@: number of type variables; must be at least 2. Smaller values
--   make the splice 'fail'.
supportedPrimFun ::
  ExpQ ->
  ExpQ ->
  ExpQ ->
  ([TypeQ] -> ExpQ) ->
  String ->
  String ->
  Name ->
  Int ->
  DecsQ
supportedPrimFun
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg
    | numArg < 2 =
        -- A function type needs at least one argument and one result type.
        -- Without this check, 0 crashed inside 'foldl1' and 1 produced a
        -- meaningless @SupportedPrim a@ instance.
        fail $
          "supportedPrimFun: numArg must be at least 2, got "
            <> show numArg
    | otherwise = do
        names <- traverse (newName . ("a" <>) . show) [0 .. numArg - 1]
        let tyVars = varT <$> names
            -- The guard above guarantees at least two elements, so these
            -- list functions are total here.
            argTy = head tyVars
            resTy = last tyVars
            restTys = tail tyVars
            fullTy = funType tyVars
        knd <- newName "knd"
        knd' <- newName "knd'"
        let kndty = varT knd
            knd'ty = varT knd'
        instanceWithOverlapDescD
          (if numArg == 2 then Nothing else Just Overlapping)
          (constraints tyVars)
          [t|SupportedPrim $fullTy|]
          [ [d|$(varP 'sameCon) = (==)|],
            [d|$(varP 'defaultValue) = $dv|],
            [d|$(varP 'pevalITETerm) = $ite|],
            [d|
              $(varP 'pevalEqTerm) =
                $( translateError
                     tyVars
                     "does not support equality comparison."
                 )
              |],
            [d|
              $(varP 'pevalDistinctTerm) =
                $( translateError
                     tyVars
                     "does not support equality comparison."
                 )
              |],
            [d|
              $(varP 'conSBVTerm) = $(consbv tyVars)
              |],
            [d|
              $(varP 'symSBVName) = \_ num ->
                $(stringE $ funNamePrefix <> show numArg <> "_") <> show num
              |],
            [d|
              $(varP 'symSBVTerm) = \r ->
                withPrim @($fullTy) $ return $ SBV.uninterpret r
              |],
            [d|$(varP 'withPrim) = $(withPrims tyVars)|],
            [d|
              $(varP 'sbvEq) =
                $( translateError
                     tyVars
                     "does not support equality comparison."
                 )
              |],
            [d|
              $(varP 'sbvDistinct) =
                $( translateError
                     tyVars
                     "does not support equality comparison."
                 )
              |],
            [d|$(varP 'parseSMTModelResult) = $parse|],
            -- Explicit signature so that knd' is a specified (type-applicable)
            -- variable and knd is inferred.
            (: [])
              <$> sigD
                'castTypedSymbol
                ( forallT
                    [plainTVInferred knd, plainTVSpecified knd']
                    ((: []) <$> [t|IsSymbolKind $knd'ty|])
                    [t|
                      TypedSymbol $kndty $fullTy ->
                      Maybe (TypedSymbol $knd'ty $fullTy)
                      |]
                ),
            [d|
              $(varP 'castTypedSymbol) = \sym ->
                case decideSymbolKind @($knd'ty) of
                  Left HRefl -> Nothing
                  Right HRefl -> Just $ typedAnySymbol $ unTypedSymbol sym
              |],
            -- The binary case compares two applications directly. Higher
            -- arities apply f once and recurse into the instance for the
            -- remaining (shorter) function type.
            if numArg == 2
              then
                [d|
                  $(varP 'funcDummyConstraint) = \f ->
                    withPrim @($fullTy) $
                      withNonFuncPrim @($resTy) $
                        f (conSBVTerm (defaultValue :: $argTy))
                          SBV..== f (conSBVTerm (defaultValue :: $argTy))
                  |]
              else
                [d|
                  $(varP 'funcDummyConstraint) = \f ->
                    withNonFuncPrim @($argTy) $
                      funcDummyConstraint @($(funType restTys))
                        (f (conSBVTerm (defaultValue :: $argTy)))
                  |]
          ]
    where
      -- Expression calling 'translateTypeError' with an internal-error message
      -- (prefixed by "BUG. Please send a bug report. " and funNameInError)
      -- and the 'TypeRep' of the function type built from the given
      -- type variables.
      translateError :: [TypeQ] -> String -> ExpQ
      translateError tys finalMsg =
        [|
          translateTypeError
            ( Just
                $( stringE $
                     "BUG. Please send a bug report. "
                       <> funNameInError
                       <> " "
                       <> finalMsg
                 )
            )
            (typeRep :: TypeRep $(funType tys))
          |]

      -- One 'SupportedNonFuncPrim' constraint per type variable, in order.
      constraints :: [TypeQ] -> Q Cxt
      constraints = traverse (\ty -> [t|SupportedNonFuncPrim $ty|])

      -- Right-nested application of funTypeName:
      -- [a0, a1, a2] becomes F a0 (F a1 a2).
      funType :: [TypeQ] -> TypeQ
      funType =
        foldr1 (\hd rest -> [t|$(varT funTypeName) $hd $rest|])

      -- Builds \r -> withNonFuncPrim @a0 (withNonFuncPrim @a1 (... r)).
      withPrims :: [TypeQ] -> ExpQ
      withPrims tys = do
        r <- newName "r"
        lamE [varP r] $
          foldr
            (\ty body -> [|withNonFuncPrim @($ty) $body|])
            (varE r)
            tys
-- | Generate 'SupportedPrim' instances for every arity from 2 up to and
-- including @numArg@.
--
-- 'supportedPrimFun' is called once per arity, in increasing order, with the
-- same arguments, and the resulting declarations are concatenated. When
-- @numArg < 2@ the range is empty and no declarations are produced.
supportedPrimFunUpTo ::
  ExpQ -> ExpQ -> ExpQ -> ([TypeQ] -> ExpQ) -> String -> String -> Name -> Int -> DecsQ
supportedPrimFunUpTo
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg =
    concat
      <$> traverse
        ( supportedPrimFun
            dv
            ite
            parse
            consbv
            funNameInError
            funNamePrefix
            funTypeName
        )
        [2 .. numArg]