{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# HLINT ignore "Eta reduce" #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Avoid lambda" #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.TabularFun
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.TabularFun
  ( type (=->) (..),
  )
where

import Control.DeepSeq (NFData, NFData1)
import Data.Bifunctor (Bifunctor (second))
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable)
import qualified Data.SBV as SBV
import qualified Data.SBV.Dynamic as SBVD
import qualified Data.Serialize as Cereal
import GHC.Generics (Generic, Generic1)
import Grisette.Internal.Core.Data.Class.Function
  ( Apply (FunType, apply),
    Function ((#)),
  )
import Grisette.Internal.SymPrim.FunInstanceGen (supportedPrimFunUpTo)
import Grisette.Internal.SymPrim.Prim.Internal.PartialEval (totalize2)
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( NonFuncPrimConstraint,
    NonFuncSBVRep (NonFuncSBVBaseType),
    PEvalApplyTerm (pevalApplyTerm, sbvApplyTerm),
    SBVRep (SBVType),
    SupportedNonFuncPrim (conNonFuncSBVTerm, withNonFuncPrim),
    SupportedPrim
      ( conSBVTerm,
        defaultValue,
        parseSMTModelResult,
        pevalITETerm,
        withPrim
      ),
    SupportedPrimConstraint (PrimConstraint),
    Term,
    applyTerm,
    conTerm,
    partitionCVArg,
    pevalEqTerm,
    pevalITEBasicTerm,
    pattern ConTerm,
  )
import Language.Haskell.TH.Syntax (Lift)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim

-- |
-- Functions as a table. Use the `#` operator to apply the function.
--
-- >>> let f = TabularFun [(1, 2), (3, 4)] 0 :: Int =-> Int
-- >>> f # 1
-- 2
-- >>> f # 2
-- 0
-- >>> f # 3
-- 4
data (=->) a b = TabularFun {forall a b. (a =-> b) -> [(a, b)]
funcTable :: [(a, b)], forall a b. (a =-> b) -> b
defaultFuncValue :: b}
  deriving (Int -> (a =-> b) -> ShowS
[a =-> b] -> ShowS
(a =-> b) -> String
(Int -> (a =-> b) -> ShowS)
-> ((a =-> b) -> String) -> ([a =-> b] -> ShowS) -> Show (a =-> b)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall a b. (Show a, Show b) => Int -> (a =-> b) -> ShowS
forall a b. (Show a, Show b) => [a =-> b] -> ShowS
forall a b. (Show a, Show b) => (a =-> b) -> String
$cshowsPrec :: forall a b. (Show a, Show b) => Int -> (a =-> b) -> ShowS
showsPrec :: Int -> (a =-> b) -> ShowS
$cshow :: forall a b. (Show a, Show b) => (a =-> b) -> String
show :: (a =-> b) -> String
$cshowList :: forall a b. (Show a, Show b) => [a =-> b] -> ShowS
showList :: [a =-> b] -> ShowS
Show, (a =-> b) -> (a =-> b) -> Bool
((a =-> b) -> (a =-> b) -> Bool)
-> ((a =-> b) -> (a =-> b) -> Bool) -> Eq (a =-> b)
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
forall a b. (Eq a, Eq b) => (a =-> b) -> (a =-> b) -> Bool
$c== :: forall a b. (Eq a, Eq b) => (a =-> b) -> (a =-> b) -> Bool
== :: (a =-> b) -> (a =-> b) -> Bool
$c/= :: forall a b. (Eq a, Eq b) => (a =-> b) -> (a =-> b) -> Bool
/= :: (a =-> b) -> (a =-> b) -> Bool
Eq, (forall x. (a =-> b) -> Rep (a =-> b) x)
-> (forall x. Rep (a =-> b) x -> a =-> b) -> Generic (a =-> b)
forall x. Rep (a =-> b) x -> a =-> b
forall x. (a =-> b) -> Rep (a =-> b) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall a b x. Rep (a =-> b) x -> a =-> b
forall a b x. (a =-> b) -> Rep (a =-> b) x
$cfrom :: forall a b x. (a =-> b) -> Rep (a =-> b) x
from :: forall x. (a =-> b) -> Rep (a =-> b) x
$cto :: forall a b x. Rep (a =-> b) x -> a =-> b
to :: forall x. Rep (a =-> b) x -> a =-> b
Generic, (forall a. (a =-> a) -> Rep1 ((=->) a) a)
-> (forall a. Rep1 ((=->) a) a -> a =-> a) -> Generic1 ((=->) a)
forall a. Rep1 ((=->) a) a -> a =-> a
forall a. (a =-> a) -> Rep1 ((=->) a) a
forall a a. Rep1 ((=->) a) a -> a =-> a
forall a a. (a =-> a) -> Rep1 ((=->) a) a
forall k (f :: k -> *).
(forall (a :: k). f a -> Rep1 f a)
-> (forall (a :: k). Rep1 f a -> f a) -> Generic1 f
$cfrom1 :: forall a a. (a =-> a) -> Rep1 ((=->) a) a
from1 :: forall a. (a =-> a) -> Rep1 ((=->) a) a
$cto1 :: forall a a. Rep1 ((=->) a) a -> a =-> a
to1 :: forall a. Rep1 ((=->) a) a -> a =-> a
Generic1, (forall (m :: * -> *). Quote m => (a =-> b) -> m Exp)
-> (forall (m :: * -> *). Quote m => (a =-> b) -> Code m (a =-> b))
-> Lift (a =-> b)
forall a b (m :: * -> *).
(Lift a, Lift b, Quote m) =>
(a =-> b) -> m Exp
forall a b (m :: * -> *).
(Lift a, Lift b, Quote m) =>
(a =-> b) -> Code m (a =-> b)
forall t.
(forall (m :: * -> *). Quote m => t -> m Exp)
-> (forall (m :: * -> *). Quote m => t -> Code m t) -> Lift t
forall (m :: * -> *). Quote m => (a =-> b) -> m Exp
forall (m :: * -> *). Quote m => (a =-> b) -> Code m (a =-> b)
$clift :: forall a b (m :: * -> *).
(Lift a, Lift b, Quote m) =>
(a =-> b) -> m Exp
lift :: forall (m :: * -> *). Quote m => (a =-> b) -> m Exp
$cliftTyped :: forall a b (m :: * -> *).
(Lift a, Lift b, Quote m) =>
(a =-> b) -> Code m (a =-> b)
liftTyped :: forall (m :: * -> *). Quote m => (a =-> b) -> Code m (a =-> b)
Lift, (a =-> b) -> ()
((a =-> b) -> ()) -> NFData (a =-> b)
forall a. (a -> ()) -> NFData a
forall a b. (NFData a, NFData b) => (a =-> b) -> ()
$crnf :: forall a b. (NFData a, NFData b) => (a =-> b) -> ()
rnf :: (a =-> b) -> ()
NFData, (forall a. NFData a => NFData (a =-> a)) =>
(forall a. (a -> ()) -> (a =-> a) -> ()) -> NFData1 ((=->) a)
forall a. NFData a => NFData (a =-> a)
forall a a. (NFData a, NFData a) => NFData (a =-> a)
forall a a. NFData a => (a -> ()) -> (a =-> a) -> ()
forall a. (a -> ()) -> (a =-> a) -> ()
forall (f :: * -> *).
(forall a. NFData a => NFData (f a)) =>
(forall a. (a -> ()) -> f a -> ()) -> NFData1 f
$cliftRnf :: forall a a. NFData a => (a -> ()) -> (a =-> a) -> ()
liftRnf :: forall a. (a -> ()) -> (a =-> a) -> ()
NFData1, (forall (m :: * -> *). MonadPut m => (a =-> b) -> m ())
-> (forall (m :: * -> *). MonadGet m => m (a =-> b))
-> Serial (a =-> b)
forall a.
(forall (m :: * -> *). MonadPut m => a -> m ())
-> (forall (m :: * -> *). MonadGet m => m a) -> Serial a
forall a b (m :: * -> *).
(Serial a, Serial b, MonadGet m) =>
m (a =-> b)
forall a b (m :: * -> *).
(Serial a, Serial b, MonadPut m) =>
(a =-> b) -> m ()
forall (m :: * -> *). MonadGet m => m (a =-> b)
forall (m :: * -> *). MonadPut m => (a =-> b) -> m ()
$cserialize :: forall a b (m :: * -> *).
(Serial a, Serial b, MonadPut m) =>
(a =-> b) -> m ()
serialize :: forall (m :: * -> *). MonadPut m => (a =-> b) -> m ()
$cdeserialize :: forall a b (m :: * -> *).
(Serial a, Serial b, MonadGet m) =>
m (a =-> b)
deserialize :: forall (m :: * -> *). MonadGet m => m (a =-> b)
Serial)

instance (Serial a, Serial b) => Cereal.Serialize (a =-> b) where
  put :: Putter (a =-> b)
put = Putter (a =-> b)
forall a (m :: * -> *). (Serial a, MonadPut m) => a -> m ()
forall (m :: * -> *). MonadPut m => (a =-> b) -> m ()
serialize
  get :: Get (a =-> b)
get = Get (a =-> b)
forall a (m :: * -> *). (Serial a, MonadGet m) => m a
forall (m :: * -> *). MonadGet m => m (a =-> b)
deserialize

instance (Serial a, Serial b) => Binary.Binary (a =-> b) where
  put :: (a =-> b) -> Put
put = (a =-> b) -> Put
forall a (m :: * -> *). (Serial a, MonadPut m) => a -> m ()
forall (m :: * -> *). MonadPut m => (a =-> b) -> m ()
serialize
  get :: Get (a =-> b)
get = Get (a =-> b)
forall a (m :: * -> *). (Serial a, MonadGet m) => m a
forall (m :: * -> *). MonadGet m => m (a =-> b)
deserialize

infixr 0 =->

instance (Eq a) => Function (a =-> b) a b where
  (TabularFun [(a, b)]
table b
d) # :: (a =-> b) -> a -> b
# a
a = [(a, b)] -> b
go [(a, b)]
table
    where
      go :: [(a, b)] -> b
go [] = b
d
      go ((a
av, b
bv) : [(a, b)]
s)
        | a
a a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
av = b
bv
        | Bool
otherwise = [(a, b)] -> b
go [(a, b)]
s

instance (Hashable a, Hashable b) => Hashable (a =-> b)

instance
  (SupportedNonFuncPrim a, SupportedPrim b) =>
  SupportedPrimConstraint (a =-> b)
  where
  type
    PrimConstraint (a =-> b) =
      ( SupportedNonFuncPrim a,
        SupportedPrim b,
        NonFuncPrimConstraint a,
        PrimConstraint b
      )

instance (SupportedNonFuncPrim a, SupportedPrim b) => SBVRep (a =-> b) where
  type SBVType (a =-> b) = SBV.SBV (NonFuncSBVBaseType a) -> SBVType b

instance
  (SupportedPrim a, SupportedPrim b, Eq a, SupportedPrim (a =-> b)) =>
  PEvalApplyTerm (a =-> b) a b
  where
  pevalApplyTerm :: Term (a =-> b) -> Term a -> Term b
pevalApplyTerm = (Term (a =-> b) -> PartialFun (Term a) (Term b))
-> (Term (a =-> b) -> Term a -> Term b)
-> Term (a =-> b)
-> Term a
-> Term b
forall a b c. (a -> PartialFun b c) -> (a -> b -> c) -> a -> b -> c
totalize2 Term (a =-> b) -> PartialFun (Term a) (Term b)
(SupportedPrim a, SupportedPrim b) =>
Term (a =-> b) -> PartialFun (Term a) (Term b)
doPevalApplyTerm Term (a =-> b) -> Term a -> Term b
forall f a b.
(PEvalApplyTerm f a b, SupportedPrim b) =>
Term f -> Term a -> Term b
applyTerm
    where
      doPevalApplyTerm ::
        (SupportedPrim a, SupportedPrim b) =>
        Term (a =-> b) ->
        Term a ->
        Maybe (Term b)
      doPevalApplyTerm :: (SupportedPrim a, SupportedPrim b) =>
Term (a =-> b) -> PartialFun (Term a) (Term b)
doPevalApplyTerm (ConTerm a =-> b
f) (ConTerm a
a) = Term b -> Maybe (Term b)
forall a. a -> Maybe a
Just (Term b -> Maybe (Term b)) -> Term b -> Maybe (Term b)
forall a b. (a -> b) -> a -> b
$ b -> Term b
forall t. SupportedPrim t => t -> Term t
conTerm (b -> Term b) -> b -> Term b
forall a b. (a -> b) -> a -> b
$ a =-> b
f (a =-> b) -> a -> b
forall f arg ret. Function f arg ret => f -> arg -> ret
# a
a
      doPevalApplyTerm (ConTerm (TabularFun [(a, b)]
f b
d)) Term a
a = Term b -> Maybe (Term b)
forall a. a -> Maybe a
Just (Term b -> Maybe (Term b)) -> Term b -> Maybe (Term b)
forall a b. (a -> b) -> a -> b
$ [(a, b)] -> Term b
go [(a, b)]
f
        where
          go :: [(a, b)] -> Term b
go [] = b -> Term b
forall t. SupportedPrim t => t -> Term t
conTerm b
d
          go ((a
x, b
y) : [(a, b)]
xs) =
            Term Bool -> Term b -> Term b -> Term b
forall t.
SupportedPrim t =>
Term Bool -> Term t -> Term t -> Term t
pevalITETerm (Term a -> Term a -> Term Bool
forall t. SupportedPrim t => Term t -> Term t -> Term Bool
pevalEqTerm Term a
a (a -> Term a
forall t. SupportedPrim t => t -> Term t
conTerm a
x)) (b -> Term b
forall t. SupportedPrim t => t -> Term t
conTerm b
y) ([(a, b)] -> Term b
go [(a, b)]
xs)
      doPevalApplyTerm Term (a =-> b)
_ Term a
_ = Maybe (Term b)
forall a. Maybe a
Nothing
  sbvApplyTerm :: SBVType (a =-> b) -> SBVType a -> SBVType b
sbvApplyTerm SBVType (a =-> b)
f SBVType a
a =
    forall t a.
SupportedPrim t =>
((PrimConstraint t, SMTDefinable (SBVType t),
  Mergeable (SBVType t), Typeable (SBVType t)) =>
 a)
-> a
withPrim @(a =-> b) (((PrimConstraint (a =-> b), SMTDefinable (SBVType (a =-> b)),
   Mergeable (SBVType (a =-> b)), Typeable (SBVType (a =-> b))) =>
  SBVType b)
 -> SBVType b)
-> ((PrimConstraint (a =-> b), SMTDefinable (SBVType (a =-> b)),
     Mergeable (SBVType (a =-> b)), Typeable (SBVType (a =-> b))) =>
    SBVType b)
-> SBVType b
forall a b. (a -> b) -> a -> b
$ forall a r.
SupportedNonFuncPrim a =>
(NonFuncPrimConstraint a => r) -> r
withNonFuncPrim @a (((SymVal (NonFuncSBVBaseType a), EqSymbolic (SBVType a),
   Mergeable (SBVType a), SMTDefinable (SBVType a),
   Mergeable (SBVType a), SBVType a ~ SBV (NonFuncSBVBaseType a),
   PrimConstraint a) =>
  SBVType b)
 -> SBVType b)
-> ((SymVal (NonFuncSBVBaseType a), EqSymbolic (SBVType a),
     Mergeable (SBVType a), SMTDefinable (SBVType a),
     Mergeable (SBVType a), SBVType a ~ SBV (NonFuncSBVBaseType a),
     PrimConstraint a) =>
    SBVType b)
-> SBVType b
forall a b. (a -> b) -> a -> b
$ SBVType (a =-> b)
SBV (NonFuncSBVBaseType a) -> SBVType b
f SBV (NonFuncSBVBaseType a)
SBVType a
a

instance (Apply t, Eq a) => Apply (a =-> t) where
  type FunType (a =-> t) = a -> FunType t
  apply :: (a =-> t) -> FunType (a =-> t)
apply a =-> t
uf a
a = t -> FunType t
forall uf. Apply uf => uf -> FunType uf
apply (a =-> t
uf (a =-> t) -> a -> t
forall f arg ret. Function f arg ret => f -> arg -> ret
# a
a)

lowerTFunCon ::
  forall a b.
  ( SupportedNonFuncPrim a,
    SupportedPrim b,
    SBV.Mergeable (SBVType b)
  ) =>
  (a =-> b) ->
  ( SBV.SBV (NonFuncSBVBaseType a) ->
    SBVType b
  )
lowerTFunCon :: forall a b.
(SupportedNonFuncPrim a, SupportedPrim b, Mergeable (SBVType b)) =>
(a =-> b) -> SBV (NonFuncSBVBaseType a) -> SBVType b
lowerTFunCon (TabularFun [(a, b)]
l b
d) = forall a r.
SupportedNonFuncPrim a =>
(NonFuncPrimConstraint a => r) -> r
withNonFuncPrim @a ((NonFuncPrimConstraint a =>
  SBV (NonFuncSBVBaseType a) -> SBVType b)
 -> SBV (NonFuncSBVBaseType a) -> SBVType b)
-> (NonFuncPrimConstraint a =>
    SBV (NonFuncSBVBaseType a) -> SBVType b)
-> SBV (NonFuncSBVBaseType a)
-> SBVType b
forall a b. (a -> b) -> a -> b
$ [(a, b)] -> b -> SBV (NonFuncSBVBaseType a) -> SBVType b
forall {t} {t} {a}.
(SBVType t ~ SBVType t, Mergeable (SBVType t),
 SymVal (NonFuncSBVBaseType a), SupportedNonFuncPrim a,
 SupportedPrim t, SupportedPrim t) =>
[(a, t)] -> t -> SBV (NonFuncSBVBaseType a) -> SBVType t
go [(a, b)]
l b
d
  where
    go :: [(a, t)] -> t -> SBV (NonFuncSBVBaseType a) -> SBVType t
go [] t
d SBV (NonFuncSBVBaseType a)
_ = t -> SBVType t
forall t. SupportedPrim t => t -> SBVType t
conSBVTerm t
d
    go ((a
x, t
r) : [(a, t)]
xs) t
d SBV (NonFuncSBVBaseType a)
v =
      SBool -> SBVType t -> SBVType t -> SBVType t
forall a. Mergeable a => SBool -> a -> a -> a
SBV.ite (a -> SBV (NonFuncSBVBaseType a)
forall a. SupportedNonFuncPrim a => a -> SBV (NonFuncSBVBaseType a)
conNonFuncSBVTerm a
x SBV (NonFuncSBVBaseType a) -> SBV (NonFuncSBVBaseType a) -> SBool
forall a. EqSymbolic a => a -> a -> SBool
SBV..== SBV (NonFuncSBVBaseType a)
v) (t -> SBVType t
forall t. SupportedPrim t => t -> SBVType t
conSBVTerm t
r) ([(a, t)] -> t -> SBV (NonFuncSBVBaseType a) -> SBVType t
go [(a, t)]
xs t
d SBV (NonFuncSBVBaseType a)
v)

parseTabularFunSMTModelResult ::
  forall a b.
  (SupportedNonFuncPrim a, SupportedPrim b) =>
  Int ->
  ([([SBVD.CV], SBVD.CV)], SBVD.CV) ->
  a =-> b
parseTabularFunSMTModelResult :: forall a b.
(SupportedNonFuncPrim a, SupportedPrim b) =>
Int -> ([([CV], CV)], CV) -> a =-> b
parseTabularFunSMTModelResult Int
level ([([CV], CV)]
l, CV
s) =
  [(a, b)] -> b -> a =-> b
forall a b. [(a, b)] -> b -> a =-> b
TabularFun
    ( ([([CV], CV)] -> b) -> (a, [([CV], CV)]) -> (a, b)
forall b c a. (b -> c) -> (a, b) -> (a, c)
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second
        ( \[([CV], CV)]
r ->
            case [([CV], CV)]
r of
              [([], CV
v)] -> Int -> ([([CV], CV)], CV) -> b
forall t. SupportedPrim t => Int -> ([([CV], CV)], CV) -> t
parseSMTModelResult (Int
level Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ([], CV
v)
              [([CV], CV)]
_ -> Int -> ([([CV], CV)], CV) -> b
forall t. SupportedPrim t => Int -> ([([CV], CV)], CV) -> t
parseSMTModelResult (Int
level Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ([([CV], CV)]
r, CV
s)
        )
        ((a, [([CV], CV)]) -> (a, b)) -> [(a, [([CV], CV)])] -> [(a, b)]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall a.
SupportedNonFuncPrim a =>
[([CV], CV)] -> [(a, [([CV], CV)])]
partitionCVArg @a [([CV], CV)]
l
    )
    (Int -> ([([CV], CV)], CV) -> b
forall t. SupportedPrim t => Int -> ([([CV], CV)], CV) -> t
parseSMTModelResult (Int
level Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ([], CV
s))

supportedPrimFunUpTo
  [|TabularFun [] defaultValue|]
  [|pevalITEBasicTerm|]
  [|parseTabularFunSMTModelResult|]
  ( \tyVars ->
      [|
        \f ->
          withNonFuncPrim @($(last tyVars)) $
            lowerTFunCon f
        |]
  )
  "TabularFun"
  "tfunc"
  ''(=->)
  8