{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict #-}

-- |
-- Module      :   Grisette.Internal.Backend.QuantifiedStack
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Backend.QuantifiedStack
  ( QuantifiedSymbols (..),
    QuantifiedStack,
    addQuantified,
    lookupQuantified,
    emptyQuantifiedSymbols,
    addQuantifiedSymbol,
    isQuantifiedSymbol,
    emptyQuantifiedStack,
  )
where

import Data.Dynamic (Dynamic)
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S
import Data.Hashable (Hashable (hashWithSalt))
import GHC.Stack (HasCallStack)
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( IsSymbolKind,
    SomeTypedConstantSymbol,
    SomeTypedSymbol,
    SupportedPrim (castTypedSymbol),
    TypedConstantSymbol,
    TypedSymbol,
    castSomeTypedSymbol,
    someTypedSymbol,
  )

-- | A set of quantified symbols.
--
-- Every symbol is stored in the constant kind with its value type erased
-- (via 'someTypedSymbol'), so symbols of different types share one set.
newtype QuantifiedSymbols = QuantifiedSymbols
  { -- | The underlying set of quantified symbols.
    _symbols :: S.HashSet SomeTypedConstantSymbol
  }
  deriving stock (Show)

-- | An empty set of quantified symbols: no symbol is considered quantified.
emptyQuantifiedSymbols :: QuantifiedSymbols
emptyQuantifiedSymbols = QuantifiedSymbols {_symbols = S.empty}

-- | Add a quantified symbol to the set.
--
-- The symbol's value type is erased with 'someTypedSymbol' before insertion.
-- Inserting a symbol that is already present has no further effect.
addQuantifiedSymbol ::
  TypedConstantSymbol a -> QuantifiedSymbols -> QuantifiedSymbols
addQuantifiedSymbol sym qs =
  qs {_symbols = S.insert (someTypedSymbol sym) (_symbols qs)}

-- | Check if a symbol is quantified.
--
-- The symbol may be of any kind. It is first cast to the constant kind with
-- 'castTypedSymbol'. If that cast fails, the symbol is reported as not
-- quantified. Otherwise its type-erased form is looked up in the set.
isQuantifiedSymbol ::
  (SupportedPrim a, IsSymbolKind knd) =>
  TypedSymbol knd a ->
  QuantifiedSymbols ->
  Bool
isQuantifiedSymbol sym (QuantifiedSymbols syms) =
  maybe False isMember (castTypedSymbol sym)
  where
    isMember constSym = someTypedSymbol constSym `S.member` syms

-- | A stack of quantified symbols.
--
-- Each quantified symbol (stored in the constant kind with its value type
-- erased) is mapped to a 'Dynamic' value.
newtype QuantifiedStack = QuantifiedStack
  { -- | Map from each quantified symbol to its 'Dynamic' value.
    _stack :: M.HashMap SomeTypedConstantSymbol Dynamic
  }

-- | Two stacks are equal when they contain exactly the same set of symbols.
-- The 'Dynamic' values are not compared ('Dynamic' has no 'Eq' instance).
instance Eq QuantifiedStack where
  lhs == rhs = boundSymbols lhs == boundSymbols rhs
    where
      boundSymbols = M.keysSet . _stack

-- | Hashes only the set of symbols on the stack. This matches the 'Eq'
-- instance, which compares key sets and ignores the 'Dynamic' values.
--
-- The key /set/ ('M.keysSet') is hashed rather than the key /list/
-- ('M.keys'). The order of 'M.keys' is not guaranteed to be identical for
-- two maps with equal key sets (in unordered-containers, keys whose hashes
-- collide are kept in insertion order). Two stacks that are '==' could then
-- hash differently, violating the 'Hashable' law. The 'HashSet' instance is
-- order-independent.
instance Hashable QuantifiedStack where
  hashWithSalt salt (QuantifiedStack t) = hashWithSalt salt (M.keysSet t)

-- | An empty stack of quantified symbols, with no symbol bound.
emptyQuantifiedStack :: QuantifiedStack
emptyQuantifiedStack = QuantifiedStack {_stack = M.empty}

-- | Add a quantified symbol to the stack, associating it with a 'Dynamic'
-- value.
--
-- The symbol's value type is erased with 'someTypedSymbol' before insertion.
-- If the symbol is already on the stack, its previous value is replaced.
addQuantified ::
  TypedConstantSymbol a -> Dynamic -> QuantifiedStack -> QuantifiedStack
addQuantified sym val stack =
  stack {_stack = M.insert (someTypedSymbol sym) val (_stack stack)}

-- | Look up a quantified symbol in the stack.
--
-- The symbol may be of any kind. It is first cast to the constant kind with
-- 'castSomeTypedSymbol'. Returns 'Nothing' if that cast fails or if the
-- symbol is not on the stack; otherwise returns its 'Dynamic' value.
lookupQuantified ::
  (HasCallStack, IsSymbolKind knd) =>
  SomeTypedSymbol knd ->
  QuantifiedStack ->
  Maybe Dynamic
lookupQuantified sym (QuantifiedStack bindings) = do
  constSym <- castSomeTypedSymbol sym
  M.lookup constSym bindings