{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Avoid lambda" #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.FunInstanceGen
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.FunInstanceGen
  ( supportedPrimFun,
    supportedPrimFunUpTo,
  )
where

import qualified Data.SBV as SBV
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( IsSymbolKind,
    SupportedNonFuncPrim,
    SupportedPrim
      ( castTypedSymbol,
        conSBVTerm,
        defaultValue,
        funcDummyConstraint,
        parseSMTModelResult,
        pevalDistinctTerm,
        pevalEqTerm,
        pevalITETerm,
        sameCon,
        sbvDistinct,
        sbvEq,
        symSBVName,
        symSBVTerm,
        withPrim
      ),
    TypedSymbol (unTypedSymbol),
    decideSymbolKind,
    translateTypeError,
    typedAnySymbol,
    withNonFuncPrim,
  )
import Language.Haskell.TH
  ( Cxt,
    Dec (InstanceD),
    DecsQ,
    Exp,
    ExpQ,
    Name,
    Overlap (Overlapping),
    Q,
    Type,
    TypeQ,
    forallT,
    lamE,
    newName,
    sigD,
    stringE,
    varE,
    varP,
    varT,
  )
import Language.Haskell.TH.Datatype.TyVarBndr
  ( plainTVInferred,
    plainTVSpecified,
  )
import Type.Reflection (TypeRep, typeRep, type (:~~:) (HRefl))

-- | Build a single instance declaration with an optional overlap pragma.
--
-- The context, the body declarations, and the instance head are generated in
-- that order; the groups of body declarations are concatenated to form the
-- instance body.
instanceWithOverlapDescD ::
  -- | Overlap pragma attached to the instance, if any.
  Maybe Overlap ->
  -- | Instance context.
  Q Cxt ->
  -- | Instance head.
  Q Type ->
  -- | Groups of declarations forming the instance body.
  [DecsQ] ->
  DecsQ
instanceWithOverlapDescD overlap ctxtQ headQ bodyQs = do
  ctxt <- ctxtQ
  body <- concat <$> sequence bodyQs
  instHead <- headQ
  return [InstanceD overlap ctxt instHead body]

-- | Generate an instance of 'SupportedPrim' for a function type with a given
-- number of type arguments.
--
-- Fresh type variables @a0@, ..., @a(numArg - 1)@ are created and nested to
-- the right with the type constructor named by @funTypeName@. For example,
-- with @numArg = 3@ the generated head is @SupportedPrim (f a0 (f a1 a2))@,
-- where @f@ is @funTypeName@. Every type variable is constrained with
-- 'SupportedNonFuncPrim'. The two-argument instance carries no overlap pragma;
-- instances for more arguments are marked @OVERLAPPING@.
--
-- @numArg@ must be at least 2. Smaller values are rejected with a Template
-- Haskell error.
supportedPrimFun ::
  -- | Implementation of 'defaultValue'.
  ExpQ ->
  -- | Implementation of 'pevalITETerm'.
  ExpQ ->
  -- | Implementation of 'parseSMTModelResult'.
  ExpQ ->
  -- | Builds the implementation of 'conSBVTerm' from the argument types.
  ([TypeQ] -> ExpQ) ->
  -- | Name of the function type, used in internal error messages.
  String ->
  -- | Prefix of the names produced by 'symSBVName'.
  String ->
  -- | Name of the function type constructor.
  Name ->
  -- | Number of type arguments of the function type (at least 2).
  Int ->
  DecsQ
supportedPrimFun
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg
    | numArg < 2 =
        fail $
          "supportedPrimFun: the function type needs at least 2 type "
            <> "arguments, but got "
            <> show numArg
    | otherwise = do
        names <- traverse (newName . ("a" <>) . show) [0 .. numArg - 1]
        let tyVars = varT <$> names
            -- The guard above ensures tyVars has at least two elements, so the
            -- uses of head/last/tail below cannot fail.
            argTy = head tyVars
            funTy = funType tyVars
            noEquality =
              translateError tyVars "does not support equality comparison."
        knd <- newName "knd"
        knd' <- newName "knd'"
        let kndty = varT knd
            knd'ty = varT knd'
        instanceWithOverlapDescD
          -- The 2-argument instance is the most general one; instances for
          -- more arguments are more specific and must win the overlap.
          (if numArg == 2 then Nothing else Just Overlapping)
          (constraints tyVars)
          [t|SupportedPrim $funTy|]
          [ [d|$(varP 'sameCon) = (==)|],
            [d|$(varP 'defaultValue) = $dv|],
            [d|$(varP 'pevalITETerm) = $ite|],
            [d|$(varP 'pevalEqTerm) = $noEquality|],
            [d|$(varP 'pevalDistinctTerm) = $noEquality|],
            [d|$(varP 'conSBVTerm) = $(consbv tyVars)|],
            [d|
              $(varP 'symSBVName) = \_ num ->
                $(stringE $ funNamePrefix <> show numArg <> "_") <> show num
              |],
            [d|
              $(varP 'symSBVTerm) = \r ->
                withPrim @($funTy) $ return $ SBV.uninterpret r
              |],
            [d|$(varP 'withPrim) = $(withPrims tyVars)|],
            [d|$(varP 'sbvEq) = $noEquality|],
            [d|$(varP 'sbvDistinct) = $noEquality|],
            [d|$(varP 'parseSMTModelResult) = $parse|],
            (: [])
              <$> sigD
                'castTypedSymbol
                ( forallT
                    [plainTVInferred knd, plainTVSpecified knd']
                    ((: []) <$> [t|IsSymbolKind $knd'ty|])
                    [t|
                      TypedSymbol $kndty $funTy ->
                      Maybe (TypedSymbol $knd'ty $funTy)
                      |]
                ),
            [d|
              $(varP 'castTypedSymbol) = \sym ->
                case decideSymbolKind @($knd'ty) of
                  Left HRefl -> Nothing
                  Right HRefl -> Just $ typedAnySymbol $ unTypedSymbol sym
              |],
            if numArg == 2
              then
                [d|
                  $(varP 'funcDummyConstraint) = \f ->
                    withPrim @($funTy) $
                      withNonFuncPrim @($(last tyVars)) $ do
                        f (conSBVTerm (defaultValue :: $argTy))
                          SBV..== f
                            (conSBVTerm (defaultValue :: $argTy))
                  |]
              else
                -- Peel off the first argument and defer to the instance for
                -- the remaining (numArg - 1)-argument function type.
                [d|
                  $(varP 'funcDummyConstraint) = \f ->
                    withNonFuncPrim @($argTy) $
                      funcDummyConstraint @($(funType $ tail tyVars))
                        (f (conSBVTerm (defaultValue :: $argTy)))
                  |]
          ]
    where
      -- Expression that reports an internal error through 'translateTypeError'
      -- for the function type over the given type variables. The message
      -- marks reaching it as a bug.
      translateError :: [TypeQ] -> String -> ExpQ
      translateError tys finalMsg =
        [|
          translateTypeError
            ( Just
                $( stringE $
                     "BUG. Please send a bug report. "
                       <> funNameInError
                       <> " "
                       <> finalMsg
                 )
            )
            (typeRep :: TypeRep $(funType tys))
          |]

      -- One 'SupportedNonFuncPrim' constraint per type variable.
      constraints :: [TypeQ] -> Q Cxt
      constraints = traverse (\ty -> [t|SupportedNonFuncPrim $ty|])

      -- Right-nested application of the function type constructor:
      -- [a, b, c] becomes (f a (f b c)). Requires a non-empty list.
      funType :: [TypeQ] -> TypeQ
      funType = foldr1 (\ty rest -> [t|$(varT funTypeName) $ty $rest|])

      -- A lambda \r -> withNonFuncPrim @t1 (... (withNonFuncPrim @tn r)),
      -- bringing the non-function-prim evidence of every argument into scope.
      withPrims :: [TypeQ] -> Q Exp
      withPrims tys = do
        r <- newName "r"
        lamE [varP r] $
          foldr
            (\ty inner -> [|withNonFuncPrim @($ty) $inner|])
            (varE r)
            tys

-- | Generate instances of 'SupportedPrim' for function types with 2 up to the
-- given number of type arguments (inclusive).
--
-- 'supportedPrimFun' is called once per arity with all other arguments
-- forwarded unchanged, and the resulting declarations are concatenated. If
-- @numArg@ is less than 2, no declarations are generated.
supportedPrimFunUpTo ::
  -- | Implementation of 'defaultValue'.
  ExpQ ->
  -- | Implementation of 'pevalITETerm'.
  ExpQ ->
  -- | Implementation of 'parseSMTModelResult'.
  ExpQ ->
  -- | Builds the implementation of 'conSBVTerm' from the argument types.
  ([TypeQ] -> ExpQ) ->
  -- | Name of the function type, used in internal error messages.
  String ->
  -- | Prefix of the names produced by 'symSBVName'.
  String ->
  -- | Name of the function type constructor.
  Name ->
  -- | Largest number of type arguments to generate an instance for.
  Int ->
  DecsQ
supportedPrimFunUpTo
  dv
  ite
  parse
  consbv
  funNameInError
  funNamePrefix
  funTypeName
  numArg =
    concat
      <$> traverse
        ( supportedPrimFun
            dv
            ite
            parse
            consbv
            funNameInError
            funNamePrefix
            funTypeName
        )
        [2 .. numArg]