{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict #-}

-- |
-- Module      :   Grisette.Internal.Backend.SymBiMap
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Backend.SymBiMap
  ( SymBiMap (..),
    emptySymBiMap,
    sizeBiMap,
    addBiMap,
    addBiMapIntermediate,
    findStringToSymbol,
    lookupTerm,
    attachNextQuantifiedSymbolInfo,
  )
where

import Data.Dynamic (Dynamic)
import qualified Data.HashMap.Strict as M
import GHC.Stack (HasCallStack)
import Grisette.Internal.Backend.QuantifiedStack (QuantifiedStack)
import Grisette.Internal.Core.Data.SExpr (SExpr (Atom, List, NumberAtom))
import Grisette.Internal.Core.Data.Symbol
  ( mapIdentifier,
    mapMetadata,
  )
import Grisette.Internal.SymPrim.Prim.SomeTerm
  ( SomeTerm,
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( IsSymbolKind,
    SomeTypedAnySymbol,
    SomeTypedSymbol,
    TypedConstantSymbol,
    TypedSymbol (unTypedSymbol),
    castSomeTypedSymbol,
    typedConstantSymbol,
    pattern SupportedConstantTypedSymbol,
  )

-- | A bidirectional map between symbolic Grisette terms and sbv terms.
--
-- All fields are strict because this module enables the @Strict@ extension.
data SymBiMap = SymBiMap
  { -- | Maps each Grisette term to a function that, given a
    -- 'QuantifiedStack', produces the corresponding sbv value as a 'Dynamic'.
    biMapToSBV :: M.HashMap SomeTerm (QuantifiedStack -> Dynamic),
    -- | Insertion counter, bumped by one on every 'addBiMap' /
    -- 'addBiMapIntermediate' call (even when an existing key is replaced).
    biMapSize :: Int,
    -- | Maps string names on the sbv side back to the Grisette symbols
    -- registered under them.
    biMapFromSBV :: M.HashMap String SomeTypedAnySymbol,
    -- | The number that the next quantified symbol will be tagged with.
    quantifiedSymbolNum :: Int
  }

instance Show SymBiMap where
  -- Debug rendering: the size counter, the keys of the term-to-sbv map and
  -- the entries of the name-to-symbol map. The quantified-symbol counter is
  -- intentionally not printed.
  show m =
    concat
      [ "SymBiMap { size: ",
        show (biMapSize m),
        ", toSBV: ",
        show (M.keys (biMapToSBV m)),
        ", fromSBV: ",
        show (M.toList (biMapFromSBV m)),
        " }"
      ]

-- NOTE(review): unused, commented-out wrapper type for quantified-symbol
-- numbers; the numbering is tracked by the 'quantifiedSymbolNum' field.
-- newtype QuantifiedSymbolInfo = QuantifiedSymbolInfo Int
--   deriving (Generic, Ord, Eq, Show, Hashable, Lift, NFData)

-- | Reserve the next quantified-symbol number.
--
-- Returns the map with 'quantifiedSymbolNum' incremented by one, together
-- with a function that wraps existing symbol metadata @meta@ as the
-- S-expression @(grisette-quantified n meta)@, where @n@ is the counter value
-- before the increment.
nextQuantifiedSymbolInfo :: SymBiMap -> (SymBiMap, SExpr -> SExpr)
nextQuantifiedSymbolInfo m =
  (m {quantifiedSymbolNum = current + 1}, tagWith)
  where
    current = quantifiedSymbolNum m
    tagWith meta =
      List [Atom "grisette-quantified", NumberAtom (fromIntegral current), meta]

-- | Apply a metadata transformation to the identifier of a constant symbol.
--
-- The symbol is rebuilt with 'typedConstantSymbol', which needs the
-- 'SupportedConstantTypedSymbol' pattern to match so that the symbol's type
-- is a supported non-function primitive. If the pattern does not match, this
-- calls 'error' with a message naming this function. Callers are expected
-- never to reach that case.
attachQuantifiedSymbolInfo ::
  (SExpr -> SExpr) -> TypedConstantSymbol a -> TypedConstantSymbol a
attachQuantifiedSymbolInfo info tsym@SupportedConstantTypedSymbol =
  typedConstantSymbol $
    mapIdentifier (mapMetadata info) $
      unTypedSymbol tsym
attachQuantifiedSymbolInfo _ _ =
  error $
    "attachQuantifiedSymbolInfo: symbol type is not a supported "
      ++ "non-function primitive type; this should not happen"

-- | Reserve the next quantified-symbol number in the map (via
-- 'nextQuantifiedSymbolInfo') and record it in the identifier metadata of
-- the given constant symbol.
--
-- Returns the map with the counter advanced, together with the annotated
-- symbol.
attachNextQuantifiedSymbolInfo ::
  SymBiMap -> TypedConstantSymbol a -> (SymBiMap, TypedConstantSymbol a)
attachNextQuantifiedSymbolInfo m s =
  case nextQuantifiedSymbolInfo m of
    (advanced, tagger) -> (advanced, attachQuantifiedSymbolInfo tagger s)

-- | An empty bidirectional map: both maps are empty, and the size counter and
-- the quantified-symbol counter both start at @0@.
emptySymBiMap :: SymBiMap
emptySymBiMap =
  SymBiMap
    { biMapToSBV = M.empty,
      biMapSize = 0,
      biMapFromSBV = M.empty,
      quantifiedSymbolNum = 0
    }

-- | The size of the bidirectional map, i.e. its 'biMapSize' counter.
--
-- This counts insertions rather than distinct keys: 'addBiMap' and
-- 'addBiMapIntermediate' increment it even when they replace an entry.
sizeBiMap :: SymBiMap -> Int
sizeBiMap SymBiMap {biMapSize = sz} = sz

-- | Add a new entry to the bidirectional map, in both directions.
--
-- The term @s@ is mapped to the sbv value @d@ (ignoring the quantifier
-- stack). The name @n@ is mapped to the symbol @sb@, widened to
-- 'SomeTypedAnySymbol'. The size counter is incremented by one; an existing
-- entry under the same key is replaced.
--
-- Calls 'error' if @sb@ cannot be cast to 'SomeTypedAnySymbol'.
addBiMap ::
  (HasCallStack) =>
  SomeTerm ->
  Dynamic ->
  String ->
  SomeTypedSymbol knd ->
  SymBiMap ->
  SymBiMap
addBiMap s d n sb m =
  maybe
    (error "Casting to AnySymbol, should not fail")
    insertBoth
    (castSomeTypedSymbol sb)
  where
    insertBoth anySym =
      m
        { biMapToSBV = M.insert s (const d) (biMapToSBV m),
          biMapSize = biMapSize m + 1,
          biMapFromSBV = M.insert n anySym (biMapFromSBV m)
        }

-- | Add a new entry to the bidirectional map for intermediate values.
--
-- The term @s@ is mapped to the stack-dependent value builder @d@ (replacing
-- any existing entry), and the size counter is incremented by one. Unlike
-- 'addBiMap', nothing is recorded in the name-to-symbol map.
addBiMapIntermediate ::
  (HasCallStack) => SomeTerm -> (QuantifiedStack -> Dynamic) -> SymBiMap -> SymBiMap
addBiMapIntermediate s d m =
  m
    { biMapToSBV = M.insert s d (biMapToSBV m),
      biMapSize = biMapSize m + 1
    }

-- | Find the Grisette symbol registered under the given name, cast to the
-- requested symbol kind.
--
-- Returns 'Nothing' if no symbol is registered under the name or if the cast
-- to @knd@ fails.
findStringToSymbol :: (IsSymbolKind knd) => String -> SymBiMap -> Maybe (SomeTypedSymbol knd)
findStringToSymbol s m = M.lookup s (biMapFromSBV m) >>= castSomeTypedSymbol

-- | Look up, in the bidirectional map, the sbv value builder registered for a
-- symbolic Grisette term.
--
-- The result, if present, is a function from the 'QuantifiedStack' to the
-- sbv value wrapped in a 'Dynamic'.
lookupTerm :: (HasCallStack) => SomeTerm -> SymBiMap -> Maybe (QuantifiedStack -> Dynamic)
lookupTerm t SymBiMap {biMapToSBV = toSBV} = M.lookup t toSBV