{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}

-- |
-- Module      :   Grisette.Internal.TH.Derivation.DeriveSymOrd
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.TH.Derivation.DeriveSymOrd
  ( deriveSymOrd,
    deriveSymOrd1,
    deriveSymOrd2,
  )
where

import Grisette.Internal.Internal.Decl.Core.Data.Class.SymOrd
  ( SymOrd (symCompare),
    SymOrd1 (liftSymCompare),
    SymOrd2 (liftSymCompare2),
  )
import Grisette.Internal.Internal.Decl.Core.Data.Class.TryMerge
  ( mrgSingle,
  )
import Grisette.Internal.TH.Derivation.BinaryOpCommon
  ( BinaryOpClassConfig
      ( BinaryOpClassConfig,
        binaryOpAllowSumType,
        binaryOpFieldConfigs,
        binaryOpInstanceNames
      ),
    BinaryOpFieldConfig
      ( BinaryOpFieldConfig,
        extraPatNames,
        fieldCombineFun,
        fieldDifferentExistentialFun,
        fieldFunExp,
        fieldFunNames,
        fieldLMatchResult,
        fieldRMatchResult,
        fieldResFun
      ),
    binaryOpAllowExistential,
    defaultFieldFunExp,
    genBinaryOpClass,
  )
import Grisette.Internal.TH.Derivation.Common (DeriveConfig)
import Language.Haskell.TH (Dec, Name, Q)

-- | Shared 'genBinaryOpClass' configuration used by 'deriveSymOrd',
-- 'deriveSymOrd1' and 'deriveSymOrd2'.
--
-- Within a single constructor, field pairs are compared left to right and
-- the first result that is not 'EQ' is returned; a constructor with no
-- fields, or whose fields all compare 'EQ', yields 'EQ'. Every generated
-- result is wrapped with 'mrgSingle'. Both sum types and existential
-- constructors are accepted.
symOrdConfig :: BinaryOpClassConfig
symOrdConfig =
  BinaryOpClassConfig
    { binaryOpFieldConfigs =
        [ BinaryOpFieldConfig
            { extraPatNames = [],
              -- Compare one pair of fields by applying the comparison
              -- expression @f@ to the left and right field expressions.
              fieldResFun = \_ (lhs, rhs) f ->
                (,[]) <$> [|$(return f) $(return lhs) $(return rhs)|],
              -- Chain the per-field results lexicographically: run each
              -- comparison in turn and stop at the first non-'EQ' result.
              fieldCombineFun = \_ results -> do
                let go [] = [|mrgSingle EQ|]
                    go [x] = return x
                    go (x : xs) =
                      [|
                        do
                          ordering <- $(return x)
                          case ordering of
                            EQ -> $(go xs)
                            _ -> mrgSingle ordering
                        |]
                (,[]) <$> go results,
              -- NOTE(review): @cmp@ is supplied by 'genBinaryOpClass' when
              -- the existential types of the two sides differ; presumably an
              -- 'Ordering' expression — confirm in BinaryOpCommon.
              fieldDifferentExistentialFun = \cmp ->
                [|mrgSingle $(return cmp)|],
              fieldFunExp = defaultFieldFunExp compareFunNames,
              fieldFunNames = compareFunNames,
              -- Results for mismatched constructors; presumably LT when the
              -- left constructor is declared first — confirm in BinaryOpCommon.
              fieldLMatchResult = [|mrgSingle LT|],
              fieldRMatchResult = [|mrgSingle GT|]
            }
        ],
      binaryOpInstanceNames = [''SymOrd, ''SymOrd1, ''SymOrd2],
      binaryOpAllowSumType = True,
      binaryOpAllowExistential = True
    }
  where
    -- Comparison methods, listed in the same order as the class names in
    -- 'binaryOpInstanceNames' above.
    compareFunNames :: [Name]
    compareFunNames = ['symCompare, 'liftSymCompare, 'liftSymCompare2]

-- | Derive a 'SymOrd' instance for a data type.
--
-- The generated comparison uses the lexicographic field ordering described
-- by 'symOrdConfig'.
deriveSymOrd ::
  -- | Derivation configuration.
  DeriveConfig ->
  -- | Name of the data type to derive the instance for.
  Name ->
  Q [Dec]
deriveSymOrd deriveConfig =
  -- NOTE(review): 0 presumably selects ''SymOrd from binaryOpInstanceNames.
  genBinaryOpClass deriveConfig symOrdConfig 0

-- | Derive a 'SymOrd1' instance for a data type.
--
-- The generated comparison uses the lexicographic field ordering described
-- by 'symOrdConfig'.
deriveSymOrd1 ::
  -- | Derivation configuration.
  DeriveConfig ->
  -- | Name of the data type to derive the instance for.
  Name ->
  Q [Dec]
deriveSymOrd1 deriveConfig =
  -- NOTE(review): 1 presumably selects ''SymOrd1 from binaryOpInstanceNames.
  genBinaryOpClass deriveConfig symOrdConfig 1

-- | Derive a 'SymOrd2' instance for a data type.
--
-- The generated comparison uses the lexicographic field ordering described
-- by 'symOrdConfig'.
deriveSymOrd2 ::
  -- | Derivation configuration.
  DeriveConfig ->
  -- | Name of the data type to derive the instance for.
  Name ->
  Q [Dec]
deriveSymOrd2 deriveConfig =
  -- NOTE(review): 2 presumably selects ''SymOrd2 from binaryOpInstanceNames.
  genBinaryOpClass deriveConfig symOrdConfig 2