{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Internal.Decl.Core.Data.Class.SymEq
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Decl.Core.Data.Class.SymEq
  ( -- * Symbolic equality
    SymEq (..),
    SymEq1 (..),
    symEq1,
    SymEq2 (..),
    symEq2,
    pairwiseSymDistinct,

    -- * More 'Eq' helper
    distinct,

    -- * Generic 'SymEq'
    SymEqArgs (..),
    GSymEq (..),
    genericSymEq,
    genericLiftSymEq,
  )
where

import Data.Kind (Type)
import Generics.Deriving
  ( Default (Default),
    Default1 (Default1),
    Generic (Rep, from),
    Generic1 (Rep1, from1),
    K1 (K1),
    M1 (M1),
    Par1 (Par1),
    Rec1 (Rec1),
    U1,
    V1,
    (:.:) (Comp1),
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Grisette.Internal.Core.Data.Class.LogicalOp (LogicalOp (symNot, (.&&)))
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.SymPrim.SymBool (SymBool)
import Grisette.Internal.Utils.Derive (Arity0, Arity1)

-- | Check if all elements in a list are distinct.
--
-- Note that empty or singleton lists are always distinct. Only 'Eq' is
-- available, so every element is compared against every later element
-- (O(n^2) comparisons).
--
-- >>> distinct []
-- True
-- >>> distinct [1]
-- True
-- >>> distinct [1, 2, 3]
-- True
-- >>> distinct [1, 2, 2]
-- False
distinct :: (Eq a) => [a] -> Bool
distinct [] = True
-- The head must differ from every later element, and the tail must itself be
-- distinct. A singleton falls into this case: @all _ []@ and @distinct []@ are
-- both 'True'.
distinct (x : xs) = all (x /=) xs && distinct xs

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim

-- | Symbolic equality. Note that we can't use Haskell's 'Eq' class since
-- symbolic comparison won't necessarily return a concrete 'Bool' value.
--
-- >>> let a = 1 :: SymInteger
-- >>> let b = 2 :: SymInteger
-- >>> a .== b
-- false
-- >>> a ./= b
-- true
--
-- >>> let a = "a" :: SymInteger
-- >>> let b = "b" :: SymInteger
-- >>> a .== b
-- (= a b)
-- >>> a ./= b
-- (distinct a b)
--
-- __Note:__ This type class can be derived for algebraic data types.
-- You may need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving SymEq via (Default X)
class SymEq a where
  -- | Symbolic equality test. The default implementation is the symbolic
  -- negation of @('./=')@.
  (.==) :: a -> a -> SymBool
  a .== b = symNot $ a ./= b
  {-# INLINE (.==) #-}
  infix 4 .==

  -- | Symbolic inequality test. The default implementation is the symbolic
  -- negation of @('.==')@.
  (./=) :: a -> a -> SymBool
  a ./= b = symNot $ a .== b
  {-# INLINE (./=) #-}
  infix 4 ./=

  -- | Check if all elements in a list are distinct, under the symbolic equality
  -- semantics.
  --
  -- The default implementation is 'pairwiseSymDistinct'.
  symDistinct :: [a] -> SymBool
  symDistinct = pairwiseSymDistinct

  -- The two comparison defaults are defined in terms of each other, so an
  -- instance must provide at least one of them.
  {-# MINIMAL (.==) | (./=) #-}

-- | Default pairwise symbolic distinct implementation.
--
-- Produces the conjunction of @x './=' y@ for every pair where @x@ occurs
-- before @y@ in the list, i.e. O(n^2) symbolic comparisons. Empty and
-- singleton lists are distinct (@'con' 'True'@).
pairwiseSymDistinct :: (SymEq a) => [a] -> SymBool
pairwiseSymDistinct [] = con True
pairwiseSymDistinct [_] = con True
pairwiseSymDistinct (x : xs) = differsFromAll x xs .&& pairwiseSymDistinct xs
  where
    -- Right-nested conjunction of @x' ./= y@ over the remaining elements,
    -- terminated by @con True@.
    differsFromAll x' = foldr (\y rest -> (x' ./= y) .&& rest) (con True)

-- | Lifting of the 'SymEq' class to unary type constructors.
--
-- Any instance should be subject to the following law that canonicity is
-- preserved:
--
-- @liftSymEq (.==)@ should be equivalent to @(.==)@, under the symbolic
-- semantics.
--
-- This class therefore represents the generalization of 'SymEq' by decomposing
-- its main method into a canonical lifting on a canonical inner method, so that
-- the lifting can be reused for arguments other than the canonical one.
class (forall a. (SymEq a) => SymEq (f a)) => SymEq1 f where
  -- | Lift a symbolic equality test through the type constructor.
  --
  -- The function will usually be applied to a symbolic equality function, but
  -- the more general type ensures that the implementation uses it to compare
  -- elements of the first container with elements of the second.
  liftSymEq :: (a -> b -> SymBool) -> f a -> f b -> SymBool

-- | Lift the standard @('.==')@ function through the type constructor.
--
-- This is 'liftSymEq' specialised to the element type's own 'SymEq' instance.
symEq1 :: (SymEq a, SymEq1 f) => f a -> f a -> SymBool
symEq1 = liftSymEq (.==)

-- | Lifting of the 'SymEq' class to binary type constructors.
class (forall a. (SymEq a) => SymEq1 (f a)) => SymEq2 f where
  -- | Lift symbolic equality tests through the type constructor.
  --
  -- The functions will usually be applied to symbolic equality functions, but
  -- the more general type ensures that the implementation uses them to compare
  -- elements of the first container with elements of the second.
  liftSymEq2 ::
    (a -> b -> SymBool) ->
    (c -> d -> SymBool) ->
    f a c ->
    f b d ->
    SymBool

-- | Lift the standard @('.==')@ function through the type constructor.
--
-- This is 'liftSymEq2' specialised to the 'SymEq' instances of both type
-- arguments.
symEq2 :: (SymEq a, SymEq b, SymEq2 f) => f a b -> f a b -> SymBool
symEq2 = liftSymEq2 (.==) (.==)

-- Derivations

-- | The arguments to the generic equality function, indexed by the arity of
-- the generic representation being compared.
data family SymEqArgs arity a b :: Type

-- | 'Generic' (arity 0) representations need no extra arguments.
data instance SymEqArgs Arity0 a b = SymEqArgs0

-- | 'Generic1' (arity 1) representations carry the comparison used for
-- occurrences of the type parameter.
newtype instance SymEqArgs Arity1 a b = SymEqArgs1 (a -> b -> SymBool)

-- | The class of types that can be generically compared for symbolic equality.
--
-- @arity@ selects which 'SymEqArgs' are threaded through the traversal of the
-- generic representation @f@.
class GSymEq arity f where
  -- | Compare two generic representations, using the supplied arguments where
  -- the representation refers to its type parameter.
  gsymEq ::
    SymEqArgs arity a b ->
    f a ->
    f b ->
    SymBool

-- | Empty types have no values; comparisons are vacuously 'True'.
instance GSymEq arity V1 where
  gsymEq _ _ _ = con True
  {-# INLINE gsymEq #-}

-- | Nullary constructors carry no fields, so two of them are always equal.
instance GSymEq arity U1 where
  gsymEq _ _ _ = con True
  {-# INLINE gsymEq #-}

-- | Products are equal when both components are equal.
instance (GSymEq arity a, GSymEq arity b) => GSymEq arity (a :*: b) where
  gsymEq args (a1 :*: b1) (a2 :*: b2) =
    gsymEq args a1 a2 .&& gsymEq args b1 b2
  {-# INLINE gsymEq #-}

-- | Sums are compared branch-wise; values built from different constructors
-- are never equal.
instance (GSymEq arity a, GSymEq arity b) => GSymEq arity (a :+: b) where
  gsymEq args (L1 a1) (L1 a2) = gsymEq args a1 a2
  gsymEq args (R1 b1) (R1 b2) = gsymEq args b1 b2
  -- Mismatched constructors (L1 vs R1 or R1 vs L1).
  gsymEq _ _ _ = con False
  {-# INLINE gsymEq #-}

-- | Metadata wrappers are transparent: compare the wrapped representations.
instance (GSymEq arity a) => GSymEq arity (M1 i c a) where
  gsymEq args (M1 a1) (M1 a2) = gsymEq args a1 a2
  {-# INLINE gsymEq #-}

-- | Constant fields are compared with their own 'SymEq' instance; the extra
-- generic arguments are not consulted.
instance (SymEq a) => GSymEq arity (K1 i a) where
  gsymEq _ (K1 a) (K1 b) = a .== b
  {-# INLINE gsymEq #-}

-- | Direct occurrences of the type parameter are compared with the function
-- supplied in 'SymEqArgs1'.
instance GSymEq Arity1 Par1 where
  gsymEq (SymEqArgs1 e) (Par1 a) (Par1 b) = e a b
  {-# INLINE gsymEq #-}

-- | Occurrences of the type parameter under a type constructor @f@ are
-- compared by lifting the supplied function through @f@'s 'SymEq1' instance.
instance (SymEq1 f) => GSymEq Arity1 (Rec1 f) where
  gsymEq (SymEqArgs1 e) (Rec1 a) (Rec1 b) = liftSymEq e a b
  {-# INLINE gsymEq #-}

-- | Compositions @f (g a)@: the generic comparison of the inner @g@ is lifted
-- through @f@'s 'SymEq1' instance.
instance (SymEq1 f, GSymEq Arity1 g) => GSymEq Arity1 (f :.: g) where
  gsymEq targs (Comp1 a) (Comp1 b) = liftSymEq (gsymEq targs) a b
  {-# INLINE gsymEq #-}

-- | 'Generic'-based instance, intended for use with @DerivingVia@.
instance (Generic a, GSymEq Arity0 (Rep a)) => SymEq (Default a) where
  Default l .== Default r = genericSymEq l r
  {-# INLINE (.==) #-}

-- | Generic @('.==')@ function.
--
-- Converts both values to their 'Generic' representation and compares them
-- structurally with 'gsymEq'; no extra arguments are needed ('SymEqArgs0').
genericSymEq :: (Generic a, GSymEq Arity0 (Rep a)) => a -> a -> SymBool
genericSymEq l r = gsymEq SymEqArgs0 (from l) (from r)
{-# INLINE genericSymEq #-}

-- | 'Generic1'-based instance: compares via the 'SymEq1' ('Default1' f)
-- lifting, using the element type's own 'SymEq' instance.
instance (Generic1 f, GSymEq Arity1 (Rep1 f), SymEq a) => SymEq (Default1 f a) where
  (.==) = symEq1
  {-# INLINE (.==) #-}

-- | 'Generic1'-based lifting, intended for use with @DerivingVia@.
instance (Generic1 f, GSymEq Arity1 (Rep1 f)) => SymEq1 (Default1 f) where
  liftSymEq f (Default1 l) (Default1 r) = genericLiftSymEq f l r
  {-# INLINE liftSymEq #-}

-- | Generic 'liftSymEq' function.
--
-- Converts both containers to their 'Generic1' representation and compares
-- them structurally with 'gsymEq'. The supplied function is used wherever the
-- representation refers to the type parameter.
genericLiftSymEq ::
  (Generic1 f, GSymEq Arity1 (Rep1 f)) =>
  (a -> b -> SymBool) ->
  f a ->
  f b ->
  SymBool
genericLiftSymEq f l r = gsymEq (SymEqArgs1 f) (from1 l) (from1 r)
{-# INLINE genericLiftSymEq #-}