{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}

-- |
-- Module      :   Grisette.Internal.TH.Derivation.BinaryOpCommon
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Derivation.BinaryOpCommon
  ( BinaryOpClassConfig (..),
    BinaryOpFieldConfig (..),
    FieldFunExp,
    defaultFieldFunExp,
    genBinaryOpClause,
    genBinaryOpClass,
  )
where

import Control.Monad (replicateM, unless, when, zipWithM)
import Control.Monad.Identity (IdentityT)
import qualified Data.List as List
import qualified Data.Map as M
import Data.Maybe (catMaybes, mapMaybe)
import Data.Proxy (Proxy (Proxy))
import qualified Data.Set as S
import Grisette.Internal.TH.Derivation.Common
  ( CheckArgsResult
      ( argVars,
        constructors,
        keptVars
      ),
    DeriveConfig (unconstrainedPositions),
    checkArgs,
    ctxForVar,
    evalModeSpecializeList,
    extraConstraint,
    freshenCheckArgsResult,
    isVarUsedInFields,
    specializeResult,
  )
import Language.Haskell.TH
  ( Clause,
    Dec (FunD, InstanceD),
    Exp (VarE),
    Kind,
    Name,
    Pat (VarP, WildP),
    Q,
    Type (AppT, ConT, VarT),
    clause,
    conP,
    funD,
    nameBase,
    newName,
    normalB,
    recP,
    sigP,
    varE,
    varP,
    varT,
    wildP,
  )
import Language.Haskell.TH.Datatype
  ( ConstructorInfo (constructorFields, constructorName, constructorVars),
    TypeSubstitution (freeVariables),
    resolveTypeSynonyms,
    tvName,
  )
import Type.Reflection
  ( TypeRep,
    eqTypeRep,
    someTypeRep,
    typeRep,
    type (:~~:) (HRefl),
  )

-- | Type of field function expression generator.
--
-- The map sends the name of a type variable to the name of the variable that
-- is bound (as a function argument pattern) for that type variable, and the
-- 'Type' is the type of the field to generate the function expression for.
type FieldFunExp = M.Map Name Name -> Type -> Q Exp

-- | Default field function expression generator.
--
-- The first argument lists the names of the operation by arity: index @0@ is
-- used for a type that mentions none of the type variables in the map, and
-- index @k@ (for @k@ from 1 to 3) is applied to the expressions generated for
-- the last @k@ type arguments of the field type.
--
-- A bare type variable found in the map is translated to the variable bound
-- for it. The generator fails in 'Q' for any type it cannot handle, and also
-- when no function name was supplied for the arity it needs.
defaultFieldFunExp :: [Name] -> FieldFunExp
defaultFieldFunExp binaryOpFunNames argToFunPat = go
  where
    allArgNames :: S.Set Name
    allArgNames = M.keysSet argToFunPat

    -- True when none of the free variables of the type has a function
    -- pattern bound for it.
    typeHasNoArg :: Type -> Bool
    typeHasNoArg ty =
      S.null $ S.fromList (freeVariables ty) `S.intersection` allArgNames

    -- Total replacement for indexing with @!!@: a missing or negative arity
    -- is reported through 'fail' instead of a list-index exception.
    funAt :: Int -> Q Exp
    funAt arity
      | arity < 0 = missing
      | otherwise = case drop arity binaryOpFunNames of
          nm : _ -> varE nm
          [] -> missing
      where
        missing =
          fail $
            "defaultFieldFunExp: no function name provided for arity "
              <> show arity

    fun1 :: Type -> Q Exp
    fun1 b = [|$(funAt 1) $(go b)|]

    fun2 :: Type -> Type -> Q Exp
    fun2 b c = [|$(funAt 2) $(go b) $(go c)|]

    fun3 :: Type -> Type -> Type -> Q Exp
    fun3 b c d = [|$(funAt 3) $(go b) $(go c) $(go d)|]

    -- The order of the alternatives matters: applications headed by a type
    -- variable are tried first, then fully argument-free types, then
    -- applications whose head is argument-free.
    go :: Type -> Q Exp
    go ty = case ty of
      AppT (AppT (AppT (VarT _) b) c) d -> fun3 b c d
      AppT (AppT (VarT _) b) c -> fun2 b c
      AppT (VarT _) b -> fun1 b
      _ | typeHasNoArg ty -> funAt 0
      AppT a b | typeHasNoArg a -> fun1 b
      AppT (AppT a b) c | typeHasNoArg a -> fun2 b c
      AppT (AppT (AppT a b) c) d | typeHasNoArg a -> fun3 b c d
      VarT nm | Just pname <- M.lookup nm argToFunPat -> varE pname
      _ -> fail $ "defaultFieldFunExp: unsupported type: " <> show ty

-- | Build the function-argument patterns and the per-field function
-- expressions for one constructor.
--
-- For every entry of the argument list that is a type variable occurring
-- free in the fields, a fresh variable named @p@ is bound; every other entry
-- gets a wildcard pattern. The field function generator is then run on each
-- field with the map from those type variables to their bound variables.
funPatAndExps ::
  FieldFunExp ->
  [(Type, Kind)] ->
  [Type] ->
  Q ([Pat], [Exp])
funPatAndExps fieldFunExpGen argTypes fields = do
  boundArgs <- traverse bindArg argTypes
  let argToFunPat = M.fromList $ catMaybes boundArgs
      funPats = maybe WildP (VarP . snd) <$> boundArgs
  fieldFunExps <- traverse (fieldFunExpGen argToFunPat) fields
  return (funPats, fieldFunExps)
  where
    usedArgs :: S.Set Name
    usedArgs = S.fromList $ freeVariables fields

    -- 'Nothing' means "no function is needed for this argument"; this
    -- replaces the previous placeholder name used for unused arguments.
    bindArg :: (Type, Kind) -> Q (Maybe (Name, Name))
    bindArg (VarT nm, _)
      | nm `S.member` usedArgs = Just . (nm,) <$> newName "p"
    bindArg _ = return Nothing

-- | Configuration for a binary operation field generation on a GADT.
data BinaryOpFieldConfig = BinaryOpFieldConfig
  { -- | Base names of extra arguments of the generated function. They are
    -- bound between the per-type-argument function patterns and the two
    -- value patterns. 'genBinaryOpClause' fails when this list is non-empty
    -- and the clause being generated is not the last one.
    extraPatNames :: [String],
    -- | Produces the result for one pair of fields. It receives the
    -- expressions for the extra arguments, the pair of left and right field
    -- expressions, and the function expression generated for the field. The
    -- returned flags tell, per extra argument, whether it was used.
    fieldResFun :: [Exp] -> (Exp, Exp) -> Exp -> Q (Exp, [Bool]),
    -- | Combines the per-field results of a constructor, given the
    -- constructor name. The returned flags tell, per extra argument, whether
    -- it was used.
    fieldCombineFun :: Name -> [Exp] -> Q (Exp, [Bool]),
    -- | Produces the result used when a pair of constructor type variables
    -- turns out to be instantiated differently at run time. The argument is
    -- an expression comparing the two type representations.
    fieldDifferentExistentialFun :: Exp -> Q Exp,
    -- | Body of the clause in which only the left value matches the
    -- constructor.
    fieldLMatchResult :: Q Exp,
    -- | Body of the clause in which only the right value matches the
    -- constructor.
    fieldRMatchResult :: Q Exp,
    -- | Generator for the function expression applied to each field.
    fieldFunExp :: FieldFunExp,
    -- | Names of the functions to define; indexed by the number passed to
    -- the function generator.
    fieldFunNames :: [Name]
  }

-- | Generate the clauses of a binary operation for one constructor of a GADT.
--
-- Three clauses are built: one where both values match the constructor, one
-- where only the left value matches, and one where only the right value
-- matches. Only the first is returned when the constructor is the last one.
--
-- Fails when the configuration has extra patterns and the constructor is not
-- the last one.
genBinaryOpClause ::
  -- | Field-level configuration.
  BinaryOpFieldConfig ->
  -- | Type arguments (with kinds) used to bind the function patterns.
  [(Type, Kind)] ->
  -- | Type arguments of the right-hand side; currently unused.
  [(Type, Kind)] ->
  -- | Whether this is the last constructor.
  Bool ->
  -- | Constructor matched by the left value.
  ConstructorInfo ->
  -- | Constructor matched by the right value.
  ConstructorInfo ->
  Q [Clause]
genBinaryOpClause
  (BinaryOpFieldConfig {..})
  lhsArgNewVars
  _rhsArgNewVars
  isLast
  lhsConstructors
  rhsConstructors =
    do
      lhsFields <- mapM resolveTypeSynonyms $ constructorFields lhsConstructors
      rhsFields <- mapM resolveTypeSynonyms $ constructorFields rhsConstructors
      (funPats, defaultFieldFunExps) <-
        funPatAndExps fieldFunExp lhsArgNewVars lhsFields
      -- The single-side clauses below do not bind the extra arguments, so
      -- extra arguments are only supported when those clauses are dropped.
      unless (null extraPatNames || isLast) $
        fail "Should not happen"
      -- Named differently from the record field to avoid shadowing it.
      extraPatVarNames <- traverse newName extraPatNames
      let extraPats = VarP <$> extraPatVarNames
          extraPatExps = VarE <$> extraPatVarNames
      lhsFieldsPatNames <- replicateM (length lhsFields) $ newName "lhsField"
      rhsFieldsPatNames <- replicateM (length rhsFields) $ newName "rhsField"
      let lhsName = constructorName lhsConstructors
          rhsName = constructorName rhsConstructors
          -- Each field is matched with a signature pattern carrying its
          -- (synonym-resolved) type.
          typedFieldPat nm field = sigP (varP nm) (return field)
          lhsFieldPats =
            conP lhsName $ zipWith typedFieldPat lhsFieldsPatNames lhsFields
          rhsFieldPats =
            conP rhsName $ zipWith typedFieldPat rhsFieldsPatNames rhsFields
          -- Pattern that matches a constructor while ignoring its fields.
          -- Each side uses its own constructor name and field count; the two
          -- coincide whenever both constructors are the same one.
          anyFieldsPat conName fields
            | null fields = conP conName []
            | otherwise = recP conName []
          lhsSingleMatchPat = anyFieldsPat lhsName lhsFields
          rhsSingleMatchPat = anyFieldsPat rhsName rhsFields
          lhsFieldPatExps = VarE <$> lhsFieldsPatNames
          rhsFieldPatExps = VarE <$> rhsFieldsPatNames

      fieldResExpsAndArgsUsed <-
        zipWithM
          (fieldResFun extraPatExps)
          (zip lhsFieldPatExps rhsFieldPatExps)
          defaultFieldFunExps
      let (fieldResExps, extraArgsUsedByFields) = unzip fieldResExpsAndArgsUsed
      (resExp, extraArgsUsedByResult) <- fieldCombineFun lhsName fieldResExps

      -- Run-time equality test on the representations of two type variables.
      let eqt l r =
            [|
              eqTypeRep
                (typeRep :: TypeRep $(varT $ tvName l))
                (typeRep :: TypeRep $(varT $ tvName r))
              |]
      -- Continue with @trueCont@ when the two type variables are equal,
      -- otherwise defer to the configured handler with their ordering.
      let eqx trueCont l r = do
            cmp <-
              [|
                compare
                  (someTypeRep (Proxy :: Proxy $(varT $ tvName l)))
                  (someTypeRep (Proxy :: Proxy $(varT $ tvName r)))
                |]
            [|
              case $(eqt l r) of
                Just HRefl -> $(trueCont)
                _ ->
                  $(fieldDifferentExistentialFun cmp)
              |]
      -- Nest one equality test per pair of constructor type variables around
      -- the combined result.
      let construct = foldr (\(l, r) cont -> eqx cont l r) (return resExp)

      -- An extra argument is used if the combiner or any field uses it.
      let extraArgsUsed =
            or
              <$> List.transpose
                (extraArgsUsedByResult : extraArgsUsedByFields)
      -- Pad with 'False' so that a short flag list turns the remaining extra
      -- arguments into wildcards instead of dropping their patterns, which
      -- would change the arity of the generated clause.
      let extraArgsPats =
            zipWith
              (\pat used -> if used then pat else WildP)
              extraPats
              (extraArgsUsed ++ repeat False)
      bothMatched <-
        clause
          ((return <$> funPats ++ extraArgsPats) ++ [lhsFieldPats, rhsFieldPats])
          ( normalB . construct $
              zip
                (constructorVars lhsConstructors)
                (constructorVars rhsConstructors)
          )
          []
      lhsMatched <-
        clause
          ((wildP <$ funPats) ++ [lhsSingleMatchPat, wildP])
          (normalB fieldLMatchResult)
          []
      rhsMatched <-
        clause
          ((wildP <$ funPats) ++ [wildP, rhsSingleMatchPat])
          (normalB fieldRMatchResult)
          []
      return $
        if isLast
          then [bothMatched]
          else [bothMatched, lhsMatched, rhsMatched]

-- | Configuration for a binary operation type class generation on a GADT.
data BinaryOpClassConfig = BinaryOpClassConfig
  { -- | One configuration per function generated in the instance.
    binaryOpFieldConfigs :: [BinaryOpFieldConfig],
    -- | Names of the classes; the instance is generated for the class at the
    -- index passed to 'genBinaryOpClass', and the base name of the first
    -- class is the one handed to the argument checker.
    binaryOpInstanceNames :: [Name],
    -- | When 'False', deriving fails for a type with more than one
    -- constructor.
    binaryOpAllowSumType :: Bool,
    -- | Passed to the argument checker (for index @0@ only) when analysing
    -- the left-hand side.
    binaryOpAllowExistential :: Bool
  }

-- | Generate a function for a binary operation on a GADT.
--
-- The function is named by the entry of 'fieldFunNames' at the given index.
-- Constructors of the two sides are paired positionally; every pair but the
-- last contributes its full set of clauses and the last pair contributes only
-- the both-matched clause. A type without constructors gets a single clause
-- whose body is @error "impossible"@.
--
-- Fails in 'Q' when the index is out of range or when exactly one of the
-- constructor lists is empty.
genBinaryOpFun ::
  BinaryOpFieldConfig ->
  Int ->
  [(Type, Kind)] ->
  [(Type, Kind)] ->
  [ConstructorInfo] ->
  [ConstructorInfo] ->
  Q Dec
genBinaryOpFun
  config
  n
  lhsArgNewVars
  rhsArgNewVars
  lhsConstructors
  rhsConstructors = do
    instanceFunName <- funNameAt n
    -- Reversing exposes the last constructor by pattern matching, replacing
    -- the partial 'init' and 'last'.
    case (reverse lhsConstructors, reverse rhsConstructors) of
      ([], []) ->
        funD
          instanceFunName
          [clause [] (normalB [|error "impossible"|]) []]
      (lhsLast : lhsInitRev, rhsLast : rhsInitRev) -> do
        clauses <-
          zipWithM
            (genClause False)
            (reverse lhsInitRev)
            (reverse rhsInitRev)
        lastClause <- genClause True lhsLast rhsLast
        return $ FunD instanceFunName (concat clauses ++ lastClause)
      _ ->
        fail
          "genBinaryOpFun: exactly one of the constructor lists is empty"
    where
      genClause = genBinaryOpClause config lhsArgNewVars rhsArgNewVars

      -- Total replacement for indexing 'fieldFunNames' with @!!@.
      funNameAt :: Int -> Q Name
      funNameAt idx
        | idx < 0 = missing
        | otherwise = case drop idx (fieldFunNames config) of
            nm : _ -> return nm
            [] -> missing
        where
          missing =
            fail $
              "genBinaryOpFun: no function name at index " <> show idx

-- | Generate a type class instance for a binary operation on a GADT.
--
-- The 'Int' selects which entry of 'binaryOpInstanceNames' the instance is
-- generated for, and the 'Name' is the data type to derive for. The result
-- is a single instance declaration containing one function per entry of
-- 'binaryOpFieldConfigs'.
--
-- Fails in 'Q' when the list of instance names is empty or too short for the
-- index, and when the type has more than one constructor while
-- 'binaryOpAllowSumType' is 'False'.
genBinaryOpClass ::
  DeriveConfig -> BinaryOpClassConfig -> Int -> Name -> Q [Dec]
genBinaryOpClass deriveConfig (BinaryOpClassConfig {..}) n typName = do
  firstClassName <- case binaryOpInstanceNames of
    nm : _ -> return $ nameBase nm
    [] -> fail "genBinaryOpClass: binaryOpInstanceNames must not be empty"
  let check allowExistential =
        checkArgs
          firstClassName
          (length binaryOpInstanceNames - 1)
          typName
          allowExistential
          n
      specialize = specializeResult (evalModeSpecializeList deriveConfig)
  lhsResult <-
    specialize
      =<< freshenCheckArgsResult True
      =<< check (n == 0 && binaryOpAllowExistential)
  -- Total replacement for @binaryOpInstanceNames !! n@.
  instanceName <-
    if n < 0
      then fail $ "genBinaryOpClass: negative class index " <> show n
      else case drop n binaryOpInstanceNames of
        nm : _ -> return nm
        [] ->
          fail $ "genBinaryOpClass: no instance name at index " <> show n
  when (not binaryOpAllowSumType && length (constructors lhsResult) > 1) $
    fail $
      "Cannot derive "
        <> nameBase instanceName
        <> " for sum type"
  -- NOTE(review): unlike the left-hand side, this result is not freshened
  -- and the existential flag ignores 'binaryOpAllowExistential' -- confirm
  -- that both differences are intended.
  rhsResult <- specialize =<< check (n == 0)
  let keptVars' = keptVars lhsResult
  -- NOTE(review): this unconditionally aborts derivation for 'IdentityT'
  -- with a dump of the kept variables; it looks like leftover debugging code
  -- and should probably be removed together with the 'IdentityT' import.
  when (typName == ''IdentityT) $
    fail $
      show keptVars'
  let isTypeUsedInFields' (VarT nm) = isVarUsedInFields lhsResult nm
      isTypeUsedInFields' _ = False
      -- Kept variables that need a context: not at an unconstrained
      -- position, and actually used in some field.
      constrainedVars =
        [ var
          | (pos, var@(ty, _)) <- zip [0 ..] keptVars',
            pos `notElem` unconstrainedPositions deriveConfig,
            isTypeUsedInFields' ty
        ]
  ctxs <-
    traverse
      (uncurry $ ctxForVar (fmap ConT binaryOpInstanceNames))
      constrainedVars
  let keptType = List.foldl' AppT (ConT typName) $ fmap fst keptVars'
  instanceFuns <-
    traverse
      ( \config ->
          genBinaryOpFun
            config
            n
            (argVars lhsResult)
            (argVars rhsResult)
            (constructors lhsResult)
            (constructors rhsResult)
      )
      binaryOpFieldConfigs
  let instanceType = AppT (ConT instanceName) keptType
  extraPreds <-
    extraConstraint
      deriveConfig
      typName
      instanceName
      []
      keptVars'
      (constructors lhsResult)
  -- A type without constructors needs no per-variable contexts.
  let varPreds
        | null (constructors lhsResult) = []
        | otherwise = catMaybes ctxs
  return
    [ InstanceD
        Nothing
        (extraPreds ++ varPreds)
        instanceType
        instanceFuns
    ]