{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.Quantifier
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.Quantifier
  ( forallSet,
    forallSym,
    existsSet,
    existsSym,
    forallFresh,
    existsFresh,
  )
where

import Data.Bifunctor (Bifunctor (first))
import qualified Data.HashSet as HS
import Data.List (sort)
import GHC.Stack (HasCallStack)
import Grisette.Internal.Core.Control.Monad.Union (Union, liftUnion)
import Grisette.Internal.Core.Data.Class.ExtractSym
  ( ExtractSym (extractSymMaybe),
  )
import Grisette.Internal.Core.Data.Class.GenSym
  ( Fresh,
    FreshT (FreshT),
    GenSym (fresh),
    MonadFresh,
    liftFresh,
  )
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.PlainUnion (simpleMerge)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, mrgSingle)
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( SomeTypedSymbol (SomeTypedSymbol),
    existsTerm,
    forallTerm,
  )
import Grisette.Internal.SymPrim.Prim.Model
  ( ConstantSymbolSet,
    SymbolSet (SymbolSet),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend
-- >>> import Grisette.Lib.Base

-- | Forall quantifier over a set of constant symbols. Quantifier over functions
-- is not supported.
--
-- The symbols are sorted (by their 'Ord' instance) before the quantifiers are
-- built, so the nesting order does not depend on the unspecified iteration
-- order of the underlying hash set. The smallest symbol becomes the outermost
-- quantifier.
--
-- >>> let xsym = "x" :: TypedConstantSymbol Integer
-- >>> let ysym = "y" :: TypedConstantSymbol Integer
-- >>> let x = "x" :: SymInteger
-- >>> let y = "y" :: SymInteger
-- >>> forallSet (buildSymbolSet [xsym, ysym]) (x .== y)
-- (forall x :: Integer (forall y :: Integer (= x y)))
--
-- Only available with SBV 10.1.0 or later.
forallSet :: ConstantSymbolSet -> SymBool -> SymBool
forallSet (SymbolSet set) b =
  -- foldr nests right-to-left, so the head of the sorted list ends up
  -- as the outermost binder.
  foldr
    (\(SomeTypedSymbol s) (SymBool b') -> SymBool $ forallTerm s b')
    b
    (sort $ HS.toList set)

-- | Forall quantifier over all symbolic constants in a value. Quantifier over
-- functions is not supported.
--
-- The constant symbols are collected with 'extractSymMaybe' and handed to
-- 'forallSet'. If they cannot be extracted as a constant symbol set, this
-- calls 'error' (with the caller's call stack). Per the error message, that
-- happens when function symbols are involved.
--
-- >>> let a = ["x", "y"] :: [SymInteger]
-- >>> forallSym a $ sum a .== 0
-- (forall x :: Integer (forall y :: Integer (= (+ x y) 0)))
--
-- Only available with SBV 10.1.0 or later.
forallSym :: (HasCallStack, ExtractSym a) => a -> SymBool -> SymBool
forallSym s b =
  case extractSymMaybe s of
    Just s' -> forallSet s' b
    Nothing ->
      error
        "Cannot use forall here. Only non-function symbols can be quantified."

-- | Exists quantifier over a set of constant symbols. Quantifier over functions
-- is not supported.
--
-- The symbols are sorted (by their 'Ord' instance) before the quantifiers are
-- built, so the nesting order does not depend on the unspecified iteration
-- order of the underlying hash set. The smallest symbol becomes the outermost
-- quantifier.
--
-- >>> let xsym = "x" :: TypedConstantSymbol Integer
-- >>> let ysym = "y" :: TypedConstantSymbol Integer
-- >>> let x = "x" :: SymInteger
-- >>> let y = "y" :: SymInteger
-- >>> existsSet (buildSymbolSet [xsym, ysym]) (x .== y)
-- (exists x :: Integer (exists y :: Integer (= x y)))
--
-- Only available with SBV 10.1.0 or later.
existsSet :: ConstantSymbolSet -> SymBool -> SymBool
existsSet (SymbolSet set) b =
  -- foldr nests right-to-left, so the head of the sorted list ends up
  -- as the outermost binder.
  foldr
    (\(SomeTypedSymbol s) (SymBool b') -> SymBool $ existsTerm s b')
    b
    (sort $ HS.toList set)

-- | Exists quantifier over all symbolic constants in a value. Quantifier over
-- functions is not supported.
--
-- The constant symbols are collected with 'extractSymMaybe' and handed to
-- 'existsSet'. If they cannot be extracted as a constant symbol set, this
-- calls 'error' (with the caller's call stack). Per the error message, that
-- happens when function symbols are involved.
--
-- >>> let a = ["x", "y"] :: [SymInteger]
-- >>> existsSym a $ sum a .== 0
-- (exists x :: Integer (exists y :: Integer (= (+ x y) 0)))
--
-- Only available with SBV 10.1.0 or later.
existsSym :: (HasCallStack, ExtractSym a) => a -> SymBool -> SymBool
existsSym s b =
  case extractSymMaybe s of
    Just s' -> existsSet s' b
    Nothing ->
      error
        "Cannot use exists here. Only non-function symbols can be quantified."

-- | Runs a 'FreshT' computation over 'Union' as a plain 'Fresh' computation
-- that yields a 'Union' of results.
--
-- For a given identifier and starting index, the wrapped computation produces
-- a 'Union' of @(result, next index)@ pairs. Each result is lifted with
-- 'mrgSingle', and the whole union of pairs is then collapsed by
-- 'simpleMerge' into a single @(Union a, index)@ pair.
--
-- NOTE(review): how the per-branch indices are combined is decided by the
-- 'SimpleMergeable' instance of the index type, which is not visible here.
-- Presumably it picks an index that is fresh for every branch; confirm.
freshTUnionToFreshUnion ::
  forall a.
  (Mergeable a) =>
  FreshT Union a ->
  Fresh (Union a)
freshTUnionToFreshUnion (FreshT v) =
  FreshT $ \ident index ->
    pure . simpleMerge $ first mrgSingle <$> v ident index

-- | Forall quantifier over symbolic constants in a freshly generated value.
-- Quantifier over functions is not supported.
--
-- A value is generated from @spec@ with 'fresh', and the body @f@ is run on it
-- inside @'FreshT' 'Union'@. The resulting union of formulas is merged into a
-- single 'SymBool'. All constant symbols of the generated value are then
-- universally quantified with 'forallSym', which calls 'error' if those
-- symbols cannot be extracted as constant symbols.
--
-- The quantified value is generated before the body runs, so its fresh
-- identifiers come first (see @x\@0@ below).
--
-- >>> :{
-- x :: Fresh SymBool
-- x = forallFresh () $ \(a :: SymBool) ->
--       existsFresh () $ \(b :: SymBool) ->
--         mrgReturn $ a .== b
-- :}
--
-- >>> runFresh x "x"
-- (forall x@0 :: Bool (exists x@1 :: Bool (= x@0 x@1)))
--
-- Only available with SBV 10.1.0 or later.
forallFresh ::
  ( HasCallStack,
    ExtractSym v,
    MonadFresh m,
    GenSym spec v,
    TryMerge m
  ) =>
  spec ->
  (v -> FreshT Union SymBool) ->
  m SymBool
forallFresh spec f = do
  u <- fresh spec
  -- Run the body on every branch of the generated union, then merge the
  -- per-branch formulas into one SymBool.
  p <-
    liftFresh . fmap simpleMerge . freshTUnionToFreshUnion $
      liftUnion u >>= f
  mrgSingle $ forallSym u p

-- | Exists quantifier over symbolic constants in a freshly generated value.
-- Quantifier over functions is not supported.
--
-- A value is generated from @spec@ with 'fresh', and the body @f@ is run on it
-- inside @'FreshT' 'Union'@. The resulting union of formulas is merged into a
-- single 'SymBool'. All constant symbols of the generated value are then
-- existentially quantified with 'existsSym', which calls 'error' if those
-- symbols cannot be extracted as constant symbols.
--
-- The example below nests 'existsFresh' inside 'forallFresh'. The inner value
-- is generated after the outer one, so it receives the later identifier
-- (@x\@1@).
--
-- >>> :{
-- x :: Fresh SymBool
-- x = forallFresh () $ \(a :: SymBool) ->
--       existsFresh () $ \(b :: SymBool) ->
--         mrgReturn $ a .== b
-- :}
--
-- >>> runFresh x "x"
-- (forall x@0 :: Bool (exists x@1 :: Bool (= x@0 x@1)))
--
-- Only available with SBV 10.1.0 or later.
existsFresh ::
  ( HasCallStack,
    ExtractSym v,
    MonadFresh m,
    GenSym spec v,
    TryMerge m
  ) =>
  spec ->
  (v -> FreshT Union SymBool) ->
  m SymBool
existsFresh spec f = do
  u <- fresh spec
  -- Run the body on every branch of the generated union, then merge the
  -- per-branch formulas into one SymBool.
  p <-
    liftFresh . fmap simpleMerge . freshTUnionToFreshUnion $
      liftUnion u >>= f
  mrgSingle $ existsSym u p