{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Internal.Decl.Core.Data.Class.SimpleMergeable
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Decl.Core.Data.Class.SimpleMergeable
  ( -- * Simple mergeable types
    SimpleMergeable (..),
    SimpleMergeable1 (..),
    mrgIte1,
    SimpleMergeable2 (..),
    mrgIte2,

    -- * Generic 'SimpleMergeable'
    SimpleMergeableArgs (..),
    GSimpleMergeable (..),
    genericMrgIte,
    genericLiftMrgIte,

    -- * Symbolic branching
    SymBranching (..),
    mrgIf,
    mergeWithStrategy,
    merge,
  )
where

import Data.Kind (Type)
import GHC.Generics
  ( Generic (Rep, from, to),
    Generic1 (Rep1, from1, to1),
    K1 (K1),
    M1 (M1),
    Par1 (Par1),
    Rec1 (Rec1),
    U1,
    V1,
    (:.:) (Comp1),
    type (:*:) ((:*:)),
  )
import Generics.Deriving (Default (Default), Default1 (Default1))
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Internal.Decl.Core.Data.Class.Mergeable
  ( GMergeable,
    Mergeable (rootStrategy),
    Mergeable1,
    Mergeable2,
    MergingStrategy,
  )
import Grisette.Internal.Internal.Decl.Core.Data.Class.TryMerge
  ( TryMerge (tryMergeWithStrategy),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool)
import Grisette.Internal.Utils.Derive (Arity0, Arity1)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Identity

-- | Types equipped with a /simple/ root merge strategy: two values can be
-- merged under a symbolic condition into a single value with 'mrgIte'.
--
-- __Note:__ Instances can be derived for algebraic data types through
-- 'Generic'. This usually requires the @DerivingVia@ and
-- @DerivingStrategies@ extensions.
--
-- > data X = ...
-- >   deriving Generic
-- >   deriving (Mergeable, SimpleMergeable) via (Default X)
class Mergeable a => SimpleMergeable a where
  -- | Symbolic if-then-else using the type's simple root merge strategy.
  --
  -- >>> mrgIte "a" "b" "c" :: SymInteger
  -- (ite a b c)
  mrgIte :: SymBool -> a -> a -> a

-- | 'SimpleMergeable' lifted to unary type constructors.
--
-- The quantified superclass requires @u a@ to be 'SimpleMergeable' whenever
-- @a@ is.
class
  ( Mergeable1 u,
    forall a. (SimpleMergeable a) => SimpleMergeable (u a)
  ) =>
  SimpleMergeable1 u
  where
  -- | Merge two @u a@ values under a symbolic condition, using the supplied
  -- function to merge the contained @a@ values.
  --
  -- >>> liftMrgIte mrgIte "a" (Identity "b") (Identity "c") :: Identity SymInteger
  -- Identity (ite a b c)
  liftMrgIte :: (SymBool -> a -> a -> a) -> SymBool -> u a -> u a -> u a

-- | Lift the standard 'mrgIte' function through the type constructor, i.e.
-- @'mrgIte1' = 'liftMrgIte' 'mrgIte'@.
--
-- >>> mrgIte1 "a" (Identity "b") (Identity "c") :: Identity SymInteger
-- Identity (ite a b c)
mrgIte1 ::
  (SimpleMergeable1 u, SimpleMergeable a) => SymBool -> u a -> u a -> u a
mrgIte1 = liftMrgIte mrgIte
{-# INLINE mrgIte1 #-}

-- | 'SimpleMergeable' lifted to binary type constructors.
--
-- The quantified superclass requires @u a@ to be 'SimpleMergeable1' whenever
-- @a@ is 'SimpleMergeable'.
class
  ( Mergeable2 u,
    forall a. (SimpleMergeable a) => SimpleMergeable1 (u a)
  ) =>
  SimpleMergeable2 u
  where
  -- | Merge two @u a b@ values under a symbolic condition, using the first
  -- function for the @a@ components and the second for the @b@ components.
  --
  -- >>> liftMrgIte2 mrgIte mrgIte "a" ("b", "c") ("d", "e") :: (SymInteger, SymBool)
  -- ((ite a b d),(ite a c e))
  liftMrgIte2 ::
    (SymBool -> a -> a -> a) ->
    (SymBool -> b -> b -> b) ->
    SymBool ->
    u a b ->
    u a b ->
    u a b

-- | Lift the standard 'mrgIte' function through the binary type constructor,
-- i.e. @'mrgIte2' = 'liftMrgIte2' 'mrgIte' 'mrgIte'@.
--
-- >>> mrgIte2 "a" ("b", "c") ("d", "e") :: (SymInteger, SymBool)
-- ((ite a b d),(ite a c e))
mrgIte2 ::
  (SimpleMergeable2 u, SimpleMergeable a, SimpleMergeable b) =>
  SymBool ->
  u a b ->
  u a b ->
  u a b
mrgIte2 = liftMrgIte2 mrgIte mrgIte
{-# INLINE mrgIte2 #-}

-- | Extra arguments threaded through the generic simple-merging functions,
-- indexed by the arity of the type being merged.
data family SimpleMergeableArgs arity a :: Type

-- Arity 0 (used by 'genericMrgIte' on 'Rep'): nothing extra is needed.
data instance SimpleMergeableArgs Arity0 _ = SimpleMergeableArgs0

-- Arity 1 (used by 'genericLiftMrgIte' on 'Rep1'): carries the function that
-- merges values of the type parameter.
newtype instance SimpleMergeableArgs Arity1 a =
  SimpleMergeableArgs1 (SymBool -> a -> a -> a)

-- | Generic counterpart of 'SimpleMergeable', defined over the
-- "GHC.Generics" representation of a type.
--
-- The @arity@ index determines which 'SimpleMergeableArgs' are passed in.
class GSimpleMergeable arity f where
  -- | Merge two representation values under a symbolic condition.
  gmrgIte :: SimpleMergeableArgs arity a -> SymBool -> f a -> f a -> f a

-- 'V1' (a type with no constructors) has no values to combine; the first
-- branch is returned unchanged.
instance GSimpleMergeable arity V1 where
  gmrgIte _ _ t _ = t
  {-# INLINE gmrgIte #-}

-- 'U1' (a constructor without fields) carries no data, so the first branch is
-- returned unchanged.
instance GSimpleMergeable arity U1 where
  gmrgIte _ _ t _ = t
  {-# INLINE gmrgIte #-}

-- Products are merged component-wise: the left components are merged with
-- each other and the right components with each other, under the same
-- condition.
instance
  (GSimpleMergeable arity a, GSimpleMergeable arity b) =>
  GSimpleMergeable arity (a :*: b)
  where
  gmrgIte args cond (a1 :*: a2) (b1 :*: b2) =
    gmrgIte args cond a1 b1 :*: gmrgIte args cond a2 b2
  {-# INLINE gmrgIte #-}

-- Metadata wrappers are transparent: unwrap both sides, merge the contents,
-- and rewrap.
instance (GSimpleMergeable arity a) => GSimpleMergeable arity (M1 i c a) where
  gmrgIte args cond (M1 a) (M1 b) = M1 $ gmrgIte args cond a b
  {-# INLINE gmrgIte #-}

-- Constant fields are merged with the field type's own 'SimpleMergeable'
-- instance; the generic arguments are ignored.
instance (SimpleMergeable c) => GSimpleMergeable arity (K1 i c) where
  gmrgIte _ cond (K1 a) (K1 b) = K1 $ mrgIte cond a b
  {-# INLINE gmrgIte #-}

-- A bare occurrence of the type parameter is merged with the function carried
-- in 'SimpleMergeableArgs1'.
instance GSimpleMergeable Arity1 Par1 where
  gmrgIte (SimpleMergeableArgs1 f) cond (Par1 l) (Par1 r) = Par1 $ f cond l r
  {-# INLINE gmrgIte #-}

-- An application @f a@ of the type parameter is merged with 'liftMrgIte',
-- forwarding the element-merging function from 'SimpleMergeableArgs1'.
instance (SimpleMergeable1 f) => GSimpleMergeable Arity1 (Rec1 f) where
  gmrgIte (SimpleMergeableArgs1 f) cond (Rec1 l) (Rec1 r) =
    Rec1 $ liftMrgIte f cond l r
  {-# INLINE gmrgIte #-}

-- A composition @f (g a)@ is merged by lifting the generic merge of the inner
-- @g a@ through the outer @f@ with 'liftMrgIte'.
instance
  (SimpleMergeable1 f, GSimpleMergeable Arity1 g) =>
  GSimpleMergeable Arity1 (f :.: g)
  where
  gmrgIte targs cond (Comp1 l) (Comp1 r) =
    Comp1 $ liftMrgIte (gmrgIte targs) cond l r
  {-# INLINE gmrgIte #-}

-- Deriving support: @'Default' a@ is merged with 'genericMrgIte'.
-- The 'GMergeable' constraint is presumably required by the 'Mergeable'
-- superclass instance for 'Default' (defined elsewhere) — confirm.
instance
  (Generic a, GSimpleMergeable Arity0 (Rep a), GMergeable Arity0 (Rep a)) =>
  SimpleMergeable (Default a)
  where
  mrgIte cond (Default a) (Default b) = Default $ genericMrgIte cond a b
  {-# INLINE mrgIte #-}

-- | Generic 'mrgIte' function.
--
-- Converts both values to their 'Rep', merges the representations with
-- 'gmrgIte' (no extra arguments at arity 0), and converts the result back.
genericMrgIte ::
  (Generic a, GSimpleMergeable Arity0 (Rep a)) =>
  SymBool ->
  a ->
  a ->
  a
genericMrgIte cond a b =
  to $ gmrgIte SimpleMergeableArgs0 cond (from a) (from b)
{-# INLINE genericMrgIte #-}

-- Deriving support: @'Default1' f a@ is merged by 'mrgIte1', which goes
-- through the 'SimpleMergeable1' instance for @'Default1' f@.
instance
  ( Generic1 f,
    GSimpleMergeable Arity1 (Rep1 f),
    GMergeable Arity1 (Rep1 f),
    SimpleMergeable a
  ) =>
  SimpleMergeable (Default1 f a)
  where
  mrgIte = mrgIte1
  {-# INLINE mrgIte #-}

-- Deriving support: @'Default1' f@ lifts merging with 'genericLiftMrgIte'.
instance
  (Generic1 f, GSimpleMergeable Arity1 (Rep1 f), GMergeable Arity1 (Rep1 f)) =>
  SimpleMergeable1 (Default1 f)
  where
  liftMrgIte f c (Default1 l) (Default1 r) =
    Default1 $ genericLiftMrgIte f c l r
  {-# INLINE liftMrgIte #-}

-- | Generic 'liftMrgIte' function.
--
-- Converts both values to their 'Rep1', merges the representations with
-- 'gmrgIte' (passing the element-merging function in 'SimpleMergeableArgs1'),
-- and converts the result back.
genericLiftMrgIte ::
  (Generic1 f, GSimpleMergeable Arity1 (Rep1 f)) =>
  (SymBool -> a -> a -> a) ->
  SymBool ->
  f a ->
  f a ->
  f a
genericLiftMrgIte f c l r =
  to1 $ gmrgIte (SimpleMergeableArgs1 f) c (from1 l) (from1 r)
{-# INLINE genericLiftMrgIte #-}

-- | Special case of the 'Mergeable1' and 'SimpleMergeable1' classes for type
-- constructors that are 'SimpleMergeable' when applied to any 'Mergeable'
-- type.
--
-- This class generalizes the 'mrgIf' function to other containers, for
-- example, Unions wrapped in monad transformers.
class
  ( SimpleMergeable1 u,
    forall a. (Mergeable a) => SimpleMergeable (u a),
    TryMerge u
  ) =>
  SymBranching (u :: Type -> Type)
  where
  -- | Symbolic @if@ control flow, merging the result with the given merge
  -- strategy.
  --
  -- >>> mrgIfWithStrategy rootStrategy "a" (mrgSingle "b") (return "c") :: Union SymInteger
  -- {(ite a b c)}
  --
  -- __Note:__ Be careful when calling this directly in your code. The
  -- supplied merge strategy must be consistent with the type's root merge
  -- strategy; otherwise some internal invariants will be broken and the
  -- program may crash.
  --
  -- Use this function when the 'Mergeable' constraint cannot be resolved,
  -- e.g., when the merge strategy for the contained type is provided through
  -- 'Mergeable1'. In other cases, 'mrgIf' is usually the better choice.
  mrgIfWithStrategy :: MergingStrategy a -> SymBool -> u a -> u a -> u a

  -- | Symbolic @if@ control flow without an explicit merging strategy.
  --
  -- The result is merged only if at least one of the branches is already
  -- merged.
  mrgIfPropagatedStrategy :: SymBool -> u a -> u a -> u a

-- | Try to merge the container with a given merge strategy.
--
-- This is 'tryMergeWithStrategy' (available through the 'TryMerge'
-- superclass) restricted to 'SymBranching' containers.
mergeWithStrategy :: (SymBranching m) => MergingStrategy a -> m a -> m a
mergeWithStrategy = tryMergeWithStrategy
{-# INLINE mergeWithStrategy #-}

-- | Try to merge the container with the root strategy, i.e.
-- @'mergeWithStrategy' 'rootStrategy'@.
merge :: (SymBranching m, Mergeable a) => m a -> m a
merge = mergeWithStrategy rootStrategy
{-# INLINE merge #-}

-- | Symbolic @if@ control flow with the result merged with the type's root
-- merge strategy.
--
-- Equivalent to @'mrgIfWithStrategy' 'rootStrategy'@.
--
-- >>> mrgIf "a" (return "b") (return "c") :: Union SymInteger
-- {(ite a b c)}
mrgIf :: (SymBranching u, Mergeable a) => SymBool -> u a -> u a -> u a
mrgIf = mrgIfWithStrategy rootStrategy
{-# INLINE mrgIf #-}

-- 'SymBool' values are merged directly with the symbolic if-then-else
-- 'symIte' from 'ITEOp'.
instance SimpleMergeable SymBool where
  mrgIte = symIte
  {-# INLINE mrgIte #-}