{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      :   Grisette.Internal.TH.Derivation.UnifiedOpCommon
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Derivation.UnifiedOpCommon
  ( UnaryOpUnifiedConfig (..),
    defaultUnaryOpUnifiedFun,
  )
where

import Grisette.Internal.TH.Derivation.Common (DeriveConfig (evalModeConfig))
import Grisette.Internal.TH.Derivation.UnaryOpCommon
  ( UnaryOpFunConfig (genUnaryOpFun),
  )
import Grisette.Internal.Unified.Util (withMode)
import Language.Haskell.TH
  ( Exp (VarE),
    Kind,
    Name,
    Q,
    Type (AppT, ArrowT, StarT, VarT),
    appE,
    clause,
    funD,
    newName,
    normalB,
    varE,
    varP,
  )

-- | Default implementation for the derivation rules for a unified operation.
--
-- @funNames@ lists the unified functions to use for type parameters of kind
-- @*@, @* -> *@ and @* -> * -> *@, in that order. For a type parameter @ty@
-- of one of those kinds, the result is the matching function applied to the
-- evaluation-mode type and to @ty@ with visible type applications:
--
-- > f @modeTy @ty
--
-- Any other kind yields 'Nothing'. If @funNames@ has no entry for the
-- parameter's kind, the splice fails with a descriptive message instead of
-- crashing on a partial list operation.
defaultUnaryOpUnifiedFun :: [Name] -> Type -> (Type, Kind) -> Q (Maybe Exp)
defaultUnaryOpUnifiedFun funNames modeTy (ty, kind) =
  case kindArity kind of
    Nothing -> return Nothing
    Just arity -> case drop arity funNames of
      (funName : _) ->
        Just <$> [|$(varE funName) @($(return modeTy)) @($(return ty))|]
      [] ->
        fail $
          "defaultUnaryOpUnifiedFun: no unified function name for a type "
            ++ "parameter of arity "
            ++ show arity
  where
    -- Number of type arguments taken by a parameter of the given kind.
    -- The arrow associates to the right, so @* -> * -> *@ is represented as
    -- @AppT (AppT ArrowT StarT) (AppT (AppT ArrowT StarT) StarT)@; the
    -- previous left-nested pattern could never match a well-formed kind.
    kindArity :: Kind -> Maybe Int
    kindArity StarT = Just 0
    kindArity (AppT (AppT ArrowT StarT) StarT) = Just 1
    kindArity (AppT (AppT ArrowT StarT) (AppT (AppT ArrowT StarT) StarT)) =
      Just 2
    kindArity _ = Nothing

-- | Configuration for the derivation rules for a unified operation.
--
-- It wraps the per-type-parameter rule used when generating the method
-- body; 'defaultUnaryOpUnifiedFun' provides a default rule of this shape.
newtype UnaryOpUnifiedConfig = UnaryOpUnifiedConfig
  { -- | Given the evaluation-mode type and a type parameter paired with its
    -- kind, produce the expression to wrap around the argument for that
    -- parameter, or 'Nothing' to leave the argument unwrapped.
    unifiedFun :: Type -> (Type, Kind) -> Q (Maybe Exp)
  }

-- | Generates the method clause for a unified operation.
--
-- The generated clause has the shape
--
-- > funName r = withMode @mode (f_k (... (f_1 r))) (f_k (... (f_1 r)))
--
-- where @mode@ is the evaluation-mode type and each @f_j@ is a 'Just' result
-- of 'unifiedFun' for a kept type parameter that is a type variable accepted
-- by the field-usage predicate. Parameters for which 'unifiedFun' returns
-- 'Nothing' contribute no wrapper. The same expression is spliced into both
-- continuations of 'withMode'.
instance UnaryOpFunConfig UnaryOpUnifiedConfig where
  genUnaryOpFun
    deriveConfig
    (UnaryOpUnifiedConfig {..})
    funNames
    n
    extraVars
    keptTypes
    _
    isVarUsedInFields
    _ = do
      modeTy <- case evalModeConfig deriveConfig of
        -- No configured mode parameter: the mode is the first extra variable.
        [] -> case extraVars of
          ((extraTy, _) : _) -> return extraTy
          [] ->
            fail
              "Unified classes require an extra type variable for the evaluation mode"
        -- The mode is the kept type parameter at the configured index.
        [(i, _)] -> case lookupAt i keptTypes of
          Just (keptTy, _) -> return keptTy
          Nothing ->
            fail $
              "Evaluation mode index "
                ++ show i
                ++ " is out of range of the kept type parameters"
        _ -> fail "Unified classes does not support multiple evaluation modes"
      let isTypeUsedInFields (VarT nm) = isVarUsedInFields nm
          isTypeUsedInFields _ = False
      exprs <-
        traverse (unifiedFun modeTy) $
          filter (isTypeUsedInFields . fst) keptTypes
      rVar <- newName "r"
      -- Wrap the argument with every available function, left to right, so
      -- the last function ends up outermost.
      let wrap acc fun = appE (return fun) acc
          rf = foldl wrap (return (VarE rVar)) [fun | Just fun <- exprs]
      instanceFunName <- case lookupAt n funNames of
        Just name -> return name
        Nothing ->
          fail $ "No unified function name provided at index " ++ show n
      funD
        instanceFunName
        [ clause
            [varP rVar]
            (normalB [|withMode @($(return modeTy)) $(rf) $(rf)|])
            []
        ]
    where
      -- Total list indexing: 'Nothing' for negative or out-of-range indices.
      lookupAt :: Int -> [a] -> Maybe a
      lookupAt i xs
        | i < 0 = Nothing
        | otherwise = case drop i xs of
            (x : _) -> Just x
            [] -> Nothing