{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      :   Grisette.Internal.Unified.Class.UnifiedSolvable
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Unified.Class.UnifiedSolvable
  ( UnifiedSolvable (withBaseSolvable),
    con,
    pattern Con,
    conView,
  )
where

import Data.Type.Bool (If)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Internal.Core.Data.Class.Solvable (Solvable)
import qualified Grisette.Internal.Core.Data.Class.Solvable as Grisette
import Grisette.Internal.SymPrim.AlgReal (AlgReal)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP (FP, ValidFP)
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymBool (SymBool)
import Grisette.Internal.SymPrim.SymFP (SymFP)
import Grisette.Internal.SymPrim.SymInteger (SymInteger)
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag (C, S), IsConMode)
import Grisette.Internal.Unified.Util (DecideEvalMode, withMode)

-- $setup
-- >>> import Grisette.Core (ssym)

-- | Wrap a concrete value in a symbolic value.
--
-- In concrete mode (@'C@) the value is returned unchanged, because the
-- unified type and its concrete type coincide. In symbolic mode (@'S@) the
-- value is wrapped with 'Grisette.con' from the base 'Solvable' class.
--
-- >>> con True :: Bool
-- True
--
-- >>> con True :: SymBool
-- true
con ::
  forall mode a con. (DecideEvalMode mode, UnifiedSolvable mode a con) => con -> a
con v =
  withMode @mode
    -- Concrete mode: the base constraint provides @a ~ con@.
    (withBaseSolvable @mode @a @con v)
    -- Symbolic mode: the base constraint provides @Solvable con a@.
    (withBaseSolvable @mode @a @con $ Grisette.con v)

-- | Extract the concrete value from a symbolic value.
--
-- In concrete mode (@'C@) every value is already concrete, so the result is
-- always 'Just'. In symbolic mode (@'S@) this defers to 'Grisette.conView',
-- which yields 'Nothing' for values that are not concrete (see the @ssym@
-- example below).
--
-- >>> conView (con True :: SymBool)
-- Just True
--
-- >>> conView (ssym "a" :: SymBool)
-- Nothing
--
-- >>> conView True
-- Just True
conView ::
  forall mode a con.
  (DecideEvalMode mode, UnifiedSolvable mode a con) =>
  a ->
  Maybe con
conView v =
  withMode @mode
    -- Concrete mode: @a ~ con@, so the value itself is the concrete view.
    (withBaseSolvable @mode @a @con $ Just v)
    -- Symbolic mode: use the base 'Solvable' projection.
    (withBaseSolvable @mode @a @con $ Grisette.conView v)

-- | A pattern synonym for extracting the concrete value from a symbolic value.
--
-- Matching uses 'conView' and succeeds only when it returns 'Just'. Used as an
-- expression, @Con v@ builds the value with 'con'.
--
-- >>> case con True :: SymBool of Con v -> v
-- True
--
-- >>> case ssym "a" :: SymBool of Con v -> Just v; _ -> Nothing
-- Nothing
pattern Con :: (DecideEvalMode mode, UnifiedSolvable mode a con) => con -> a
pattern Con v <-
  (conView -> Just v)
  where
    Con v = con v

-- | A class that provides the ability to extract/wrap the concrete value
-- from/into a symbolic value.
--
-- The functional dependencies let the unified type @a@ determine both the
-- evaluation mode and the concrete type. Together, the concrete type and the
-- mode determine @a@.
class UnifiedSolvable mode a con | a -> mode con, con mode -> a where
  -- | Bring the mode-specific base constraint into scope for a continuation.
  --
  -- The constraint is chosen with 'If' on @'IsConMode' mode@:
  --
  -- * the concrete branch is @a ~ con@;
  -- * the symbolic branch is @'Solvable' con a@.
  withBaseSolvable ::
    ((If (IsConMode mode) (a ~ con) (Solvable con a)) => r) -> r

-- | Concrete booleans: 'Bool' is its own concrete representation.
instance UnifiedSolvable 'C Bool Bool where
  -- The base constraint is fixed by the instance head and discharged here,
  -- so the continuation is returned directly.
  withBaseSolvable r = r

-- | Symbolic booleans: 'SymBool' wraps and projects concrete 'Bool' values.
instance UnifiedSolvable 'S SymBool Bool where
  -- The base constraint (@Solvable Bool SymBool@) is discharged here, so the
  -- continuation is returned directly.
  withBaseSolvable r = r

-- | Concrete integers: 'Integer' is its own concrete representation.
instance UnifiedSolvable 'C Integer Integer where
  -- The base constraint is fixed by the instance head and discharged here.
  withBaseSolvable r = r

-- | Symbolic integers: 'SymInteger' wraps and projects concrete 'Integer'
-- values.
instance UnifiedSolvable 'S SymInteger Integer where
  -- The base constraint (@Solvable Integer SymInteger@) is discharged here.
  withBaseSolvable r = r

-- | Concrete algebraic reals: 'AlgReal' is its own concrete representation.
instance UnifiedSolvable 'C AlgReal AlgReal where
  -- The base constraint is fixed by the instance head and discharged here.
  withBaseSolvable r = r

-- | Symbolic algebraic reals: 'SymAlgReal' wraps and projects concrete
-- 'AlgReal' values.
instance UnifiedSolvable 'S SymAlgReal AlgReal where
  -- The base constraint (@Solvable AlgReal SymAlgReal@) is discharged here.
  withBaseSolvable r = r

-- | Concrete unsigned bit-vectors of width @n@ are their own concrete
-- representation.
instance (KnownNat n, 1 <= n) => UnifiedSolvable 'C (WordN n) (WordN n) where
  -- The base constraint is fixed by the instance head and discharged here.
  withBaseSolvable r = r

-- | Symbolic unsigned bit-vectors: 'SymWordN' wraps and projects concrete
-- 'WordN' values of the same width.
instance (KnownNat n, 1 <= n) => UnifiedSolvable 'S (SymWordN n) (WordN n) where
  -- The base constraint (@Solvable (WordN n) (SymWordN n)@) is discharged
  -- here using the instance context.
  withBaseSolvable r = r

-- | Concrete signed bit-vectors of width @n@ are their own concrete
-- representation.
instance (KnownNat n, 1 <= n) => UnifiedSolvable 'C (IntN n) (IntN n) where
  -- The base constraint is fixed by the instance head and discharged here.
  withBaseSolvable r = r

-- | Symbolic signed bit-vectors: 'SymIntN' wraps and projects concrete
-- 'IntN' values of the same width.
instance (KnownNat n, 1 <= n) => UnifiedSolvable 'S (SymIntN n) (IntN n) where
  -- The base constraint (@Solvable (IntN n) (SymIntN n)@) is discharged
  -- here using the instance context.
  withBaseSolvable r = r

-- | Concrete floating-point values are their own concrete representation.
instance (ValidFP eb sb) => UnifiedSolvable 'C (FP eb sb) (FP eb sb) where
  -- The base constraint is fixed by the instance head and discharged here.
  withBaseSolvable r = r

-- | Symbolic floating-point values: 'SymFP' wraps and projects concrete 'FP'
-- values with the same exponent/significand parameters.
instance (ValidFP eb sb) => UnifiedSolvable 'S (SymFP eb sb) (FP eb sb) where
  -- The base constraint (@Solvable (FP eb sb) (SymFP eb sb)@) is discharged
  -- here using the instance context.
  withBaseSolvable r = r