{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Internal.Decl.SymPrim.AllSyms
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Internal.Decl.SymPrim.AllSyms
  ( -- * Get all symbolic primitive values in a value
    SomeSym (..),
    AllSyms (..),
    AllSyms1 (..),
    allSymsS1,
    AllSyms2 (..),
    allSymsS2,
    allSymsSize,
    symSize,
    symsSize,

    -- * Generic 'AllSyms'
    AllSymsArgs (..),
    GAllSyms (..),
    genericAllSymsS,
    genericLiftAllSymsS,
  )
where

import Data.Kind (Type)
import GHC.Generics
  ( Generic (Rep, from),
    Generic1 (Rep1, from1),
    K1 (K1),
    M1 (M1),
    Par1 (Par1),
    Rec1 (Rec1),
    U1,
    V1,
    (:.:) (Comp1),
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Generics.Deriving
  ( Default (unDefault),
    Default1 (unDefault1),
  )
import Grisette.Internal.SymPrim.Prim.SomeTerm
  ( SomeTerm (SomeTerm),
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( LinkedRep (underlyingTerm),
    pformatTerm,
  )
import Grisette.Internal.SymPrim.Prim.TermUtils
  ( someTermsSize,
    termSize,
    termsSize,
  )
import Grisette.Internal.Utils.Derive (Arity0, Arity1)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Data.Proxy

-- | A symbolic value whose concrete type has been hidden.
--
-- Only the 'LinkedRep' dictionary for the wrapped value is retained, which
-- is what allows its underlying term to be recovered later.
data SomeSym where
  SomeSym ::
    forall con sym.
    (LinkedRep con sym) =>
    sym ->
    SomeSym

-- Showing a 'SomeSym' pretty-prints the SMT term underlying the wrapped
-- symbolic value.
instance Show SomeSym where
  show (SomeSym s) = pformatTerm $ underlyingTerm s

-- | Extract all symbolic primitive values that are represented as SMT terms.
--
-- >>> allSyms (["a" + 1 :: SymInteger, -"b"], "c" :: SymBool)
-- [(+ 1 a),(- b),c]
--
-- This is usually used for getting a statistical summary of the size of
-- a symbolic value with 'allSymsSize'.
--
-- __Note:__ This type class can be derived for algebraic data types. You may
-- need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving AllSyms via (Default X)
class AllSyms a where
  -- | Convert a value to a list of symbolic primitive values. It should
  -- prepend to an existing list of symbolic primitive values.
  allSymsS :: a -> [SomeSym] -> [SomeSym]
  -- Default via 'allSyms': collect into a fresh list, then append the
  -- existing accumulator.
  allSymsS a l = allSyms a ++ l

  -- | Specialized 'allSymsS' that prepends to an empty list.
  allSyms :: a -> [SomeSym]
  -- Default via 'allSymsS' with an empty accumulator. The two defaults are
  -- mutually recursive, so an instance must override at least one of them.
  allSyms a = allSymsS a []

  {-# MINIMAL allSymsS | allSyms #-}

-- | Get the sum of the sizes of a list of symbolic terms.
-- Duplicate sub-terms are counted for only once.
--
-- >>> symsSize [1, "a" :: SymInteger, "a" + 1 :: SymInteger]
-- 3
symsSize :: forall con sym. (LinkedRep con sym) => [sym] -> Int
-- The explicit @\@con@ application pins the extracted terms to the @con@
-- bound in the signature; the size computation itself is 'termsSize'.
symsSize = termsSize . fmap (underlyingTerm @con)
{-# INLINE symsSize #-}

-- | Get the size of a symbolic term.
-- Duplicate sub-terms are counted for only once.
--
-- >>> symSize (1 :: SymInteger)
-- 1
-- >>> symSize ("a" :: SymInteger)
-- 1
-- >>> symSize ("a" + 1 :: SymInteger)
-- 3
-- >>> symSize (("a" + 1) * ("a" + 1) :: SymInteger)
-- 4
symSize :: forall con sym. (LinkedRep con sym) => sym -> Int
-- The explicit @\@con@ application pins the extracted term to the @con@
-- bound in the signature; the size computation itself is 'termSize'.
symSize = termSize . underlyingTerm @con
{-# INLINE symSize #-}

-- | Unwrap a 'SomeSym' into a type-erased 'SomeTerm' holding the underlying
-- term of the wrapped symbolic value.
someUnderlyingTerm :: SomeSym -> SomeTerm
someUnderlyingTerm (SomeSym s) = SomeTerm $ underlyingTerm s

-- | Total size of the underlying terms of a list of 'SomeSym' values, as
-- computed by 'someTermsSize' on their type-erased terms.
someSymsSize :: [SomeSym] -> Int
someSymsSize = someTermsSize . fmap someUnderlyingTerm
{-# INLINE someSymsSize #-}

-- | Get the total size of symbolic terms in a value.
-- Duplicate sub-terms are counted for only once.
--
-- >>> allSymsSize ("a" :: SymInteger, "a" + "b" :: SymInteger, ("a" + "b") * "c" :: SymInteger)
-- 5
--
-- The 5 terms are @a@, @b@, @(+ a b)@, @c@, and @(* (+ a b) c)@.
allSymsSize :: (AllSyms a) => a -> Int
-- Gather every symbolic primitive with 'allSyms', then size them together
-- so that shared sub-terms across primitives are counted once.
allSymsSize = someSymsSize . allSyms
{-# INLINE allSymsSize #-}

-- | Unary type constructors whose symbolic primitives can be collected once
-- a collector for the element type is supplied.
--
-- The quantified superclass requires @f x@ to be an 'AllSyms' instance
-- whenever @x@ is.
class (forall x. (AllSyms x) => AllSyms (f x)) => AllSyms1 f where
  -- | Collect the symbolic primitives of an @f x@, using the given function
  -- for each element of type @x@.
  liftAllSymsS ::
    (x -> [SomeSym] -> [SomeSym]) ->
    f x ->
    [SomeSym] ->
    [SomeSym]

-- | Lift the standard 'allSymsS' function to unary type constructors.
--
-- Equivalent to 'liftAllSymsS' with the element type's own 'allSymsS'.
allSymsS1 :: (AllSyms1 f, AllSyms a) => f a -> [SomeSym] -> [SomeSym]
allSymsS1 = liftAllSymsS allSymsS
{-# INLINE allSymsS1 #-}

-- | Binary type constructors whose symbolic primitives can be collected
-- once a collector is supplied for each of the two type arguments.
--
-- The quantified superclass requires @f x@ to be an 'AllSyms1' instance
-- whenever @x@ is an 'AllSyms' instance.
class (forall x. (AllSyms x) => AllSyms1 (f x)) => AllSyms2 f where
  -- | Collect the symbolic primitives of an @f x y@, using the first
  -- function for values of type @x@ and the second for values of type @y@.
  liftAllSymsS2 ::
    (x -> [SomeSym] -> [SomeSym]) ->
    (y -> [SomeSym] -> [SomeSym]) ->
    f x y ->
    [SomeSym] ->
    [SomeSym]

-- | Lift the standard 'allSymsS' function to binary type constructors.
--
-- Equivalent to 'liftAllSymsS2' with each argument type's own 'allSymsS'.
allSymsS2 ::
  (AllSyms2 f, AllSyms a, AllSyms b) => f a b -> [SomeSym] -> [SomeSym]
allSymsS2 = liftAllSymsS2 allSymsS allSymsS
{-# INLINE allSymsS2 #-}

-- Derivation

-- | Extra arguments threaded through the generic 'AllSyms' traversal,
-- indexed by arity: nothing at arity 0, and a collector for the type
-- parameter at arity 1.
data family AllSymsArgs arity a :: Type

-- Arity 0: the traversal needs no extra information.
data instance AllSymsArgs Arity0 a = AllSymsArgs0

-- Arity 1: the collector applied to occurrences of the type parameter.
newtype instance AllSymsArgs Arity1 a = AllSymsArgs1 (a -> [SomeSym] -> [SomeSym])

-- | Generic representations (of the given arity) from which every symbolic
-- primitive can be collected.
class GAllSyms arity f where
  -- | Collect the symbolic primitives of a generic representation value,
  -- prepending them to the accumulator.
  gallSymsS ::
    AllSymsArgs arity p ->
    f p ->
    [SomeSym] ->
    [SomeSym]

-- Empty types contribute nothing: the accumulator is returned unchanged.
instance GAllSyms arity V1 where
  gallSymsS _ _ = id

-- Nullary constructors have no fields, so they contribute nothing.
instance GAllSyms arity U1 where
  gallSymsS _ _ = id

-- A constant field defers to its own 'AllSyms' instance; the arity
-- arguments are not needed here.
instance (AllSyms c) => GAllSyms arity (K1 i c) where
  gallSymsS _ (K1 x) = allSymsS x

-- Metadata wrappers are transparent: recurse into the wrapped
-- representation with the same arguments.
instance (GAllSyms arity a) => GAllSyms arity (M1 i c a) where
  gallSymsS args (M1 x) = gallSymsS args x

-- Sums: collect from whichever alternative is present.
instance (GAllSyms arity a, GAllSyms arity b) => GAllSyms arity (a :+: b) where
  gallSymsS args (L1 l) = gallSymsS args l
  gallSymsS args (R1 r) = gallSymsS args r

-- Products: collect from both components. Composing the two collectors
-- places the left component's primitives before the right component's.
instance (GAllSyms arity a, GAllSyms arity b) => GAllSyms arity (a :*: b) where
  gallSymsS args (l :*: r) = gallSymsS args l . gallSymsS args r

-- An occurrence of the type parameter itself: apply the collector supplied
-- in the arity-1 arguments.
instance GAllSyms Arity1 Par1 where
  gallSymsS (AllSymsArgs1 f) (Par1 x) = f x

-- The type parameter wrapped in another unary constructor @f@: lift the
-- supplied collector through @f@ with 'liftAllSymsS'.
instance (AllSyms1 f) => GAllSyms Arity1 (Rec1 f) where
  gallSymsS (AllSymsArgs1 f) (Rec1 x) = liftAllSymsS f x

-- Composition @f (g p)@: the inner generic collector for @g@ becomes the
-- element collector lifted through the outer constructor @f@.
instance (AllSyms1 f, GAllSyms Arity1 g) => GAllSyms Arity1 (f :.: g) where
  gallSymsS targs (Comp1 x) = liftAllSymsS (gallSymsS targs) x

-- | Generic 'allSymsS' function.
--
-- Converts the value to its 'Generic' representation and collects the
-- symbolic primitives of every field. No extra arguments are needed at
-- arity 0.
genericAllSymsS ::
  (Generic a, GAllSyms Arity0 (Rep a)) =>
  a ->
  [SomeSym] ->
  [SomeSym]
genericAllSymsS x = gallSymsS AllSymsArgs0 (from x)
{-# INLINE genericAllSymsS #-}

-- | Generic 'liftAllSymsS' function.
--
-- Converts the value to its 'Generic1' representation and collects the
-- symbolic primitives of every field. The supplied function is used for
-- each occurrence of the type parameter.
genericLiftAllSymsS ::
  (Generic1 f, GAllSyms Arity1 (Rep1 f)) =>
  (a -> [SomeSym] -> [SomeSym]) ->
  f a ->
  [SomeSym] ->
  [SomeSym]
genericLiftAllSymsS f x = gallSymsS (AllSymsArgs1 f) (from1 x)
{-# INLINE genericLiftAllSymsS #-}

-- Deriving support: @deriving AllSyms via (Default X)@ unwraps the
-- 'Default' newtype and uses the 'Generic'-based collector.
instance (Generic a, GAllSyms Arity0 (Rep a)) => AllSyms (Default a) where
  allSymsS = genericAllSymsS . unDefault
  {-# INLINE allSymsS #-}

-- Deriving support for applied unary constructors: defer to the
-- 'AllSyms1' instance of 'Default1' with the element type's 'allSymsS'.
instance
  (Generic1 f, GAllSyms Arity1 (Rep1 f), AllSyms a) =>
  AllSyms (Default1 f a)
  where
  allSymsS = allSymsS1
  {-# INLINE allSymsS #-}

-- Deriving support for 'AllSyms1': unwrap the 'Default1' newtype and use
-- the 'Generic1'-based collector.
instance (Generic1 f, GAllSyms Arity1 (Rep1 f)) => AllSyms1 (Default1 f) where
  liftAllSymsS f = genericLiftAllSymsS f . unDefault1
  {-# INLINE liftAllSymsS #-}