{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalBitCastTerm
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalBitCastTerm
  ( doPevalBitCast,
  )
where

import qualified Data.SBV as SBV
import GHC.TypeLits (KnownNat, type (+), type (<=))
import Grisette.Internal.Core.Data.Class.BitCast
  ( BitCast (bitCast),
    BitCastOr (bitCastOr),
  )
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP (FP, ValidFP, withValidFPProofs)
import Grisette.Internal.SymPrim.Prim.Internal.Instances.SupportedPrim ()
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( PEvalBitCastOrTerm (pevalBitCastOrTerm, sbvBitCastOr),
    PEvalBitCastTerm (pevalBitCastTerm, sbvBitCast),
    SupportedPrim,
    Term,
    bitCastOrTerm,
    bitCastTerm,
    conTerm,
    pattern BitCastTerm,
    pattern ConTerm,
    pattern DynTerm,
    pattern SupportedTerm,
  )
import Grisette.Internal.SymPrim.Prim.Internal.Unfold
  ( binaryUnfoldOnce,
    unaryUnfoldOnce,
  )

-- | Look through a chain of nested bitcasts for an operand that can be used
-- directly at the requested result type.
--
-- * If the argument is a 'BitCastTerm' whose operand matches 'DynTerm' at
--   type @'Term' b@, that operand is returned.
-- * If the argument is any other 'BitCastTerm', the search continues on its
--   operand.
-- * For every other term no reduction applies and the result is 'Nothing'.
doPevalBitCastSameType ::
  forall x b. (SupportedPrim b) => Term x -> Maybe (Term b)
doPevalBitCastSameType (BitCastTerm (DynTerm (b :: Term b))) = Just b
doPevalBitCastSameType (BitCastTerm x) = doPevalBitCastSameType x
doPevalBitCastSameType _ = Nothing

-- | Partially evaluate a bitcast term. If no reduction is performed, return
-- 'Nothing'.
--
-- Two reductions are attempted, in order:
--
-- 1. A concrete term ('ConTerm') is folded by applying 'bitCast' to its value
--    and wrapping the result with 'conTerm'.
-- 2. Any other term is handed to 'doPevalBitCastSameType', which looks through
--    nested bitcasts for an operand usable at the result type.
doPevalBitCast ::
  (PEvalBitCastTerm a b, SupportedPrim b) => Term a -> Maybe (Term b)
doPevalBitCast (ConTerm v) = Just $ conTerm $ bitCast v
doPevalBitCast t = doPevalBitCastSameType t

-- | General-purpose implementation of 'pevalBitCastTerm', shared by the
-- instances in this module.
--
-- It is built by passing the partial rule 'doPevalBitCast' and the plain term
-- constructor 'bitCastTerm' to 'unaryUnfoldOnce'.
--
-- NOTE(review): presumably 'unaryUnfoldOnce' falls back to 'bitCastTerm' when
-- the partial rule yields 'Nothing' -- confirm against the @Unfold@ module.
pevalBitCastGeneral ::
  forall a b.
  (PEvalBitCastTerm a b, SupportedPrim b) =>
  Term a ->
  Term b
pevalBitCastGeneral = unaryUnfoldOnce doPevalBitCast bitCastTerm

-- | Partial rule for a bitcast with a fallback value.
--
-- The first argument is the default of the result type, the second is the
-- term being cast. Only the fully concrete case is reduced: when both are
-- 'ConTerm's, 'bitCastOr' is applied to their values and the result is
-- wrapped with 'conTerm'. Every other combination yields 'Nothing'.
doPevalBitCastOr ::
  (PEvalBitCastOrTerm a b) =>
  Term b ->
  Term a ->
  Maybe (Term b)
doPevalBitCastOr (ConTerm d) (ConTerm v) =
  Just $ conTerm $ bitCastOr d v
doPevalBitCastOr _ _ = Nothing

-- | General-purpose implementation of 'pevalBitCastOrTerm', shared by the
-- instances in this module.
--
-- It is built by passing the partial rule 'doPevalBitCastOr' and the plain
-- term constructor 'bitCastOrTerm' to 'binaryUnfoldOnce'.
--
-- NOTE(review): the default term is matched against 'SupportedTerm',
-- presumably to bring the @SupportedPrim b@ evidence into scope for
-- 'binaryUnfoldOnce' -- confirm against the pattern's definition.
pevalBitCastOr ::
  forall a b.
  (PEvalBitCastOrTerm a b) =>
  Term b ->
  Term a ->
  Term b
pevalBitCastOr d@SupportedTerm =
  binaryUnfoldOnce doPevalBitCastOr bitCastOrTerm d

-- | Bitcast from 'Bool' to a 1-bit signed bit-vector.
instance PEvalBitCastTerm Bool (IntN 1) where
  pevalBitCastTerm = pevalBitCastGeneral

  -- Solver-level encoding: select the literal 1 when the boolean holds and
  -- the literal 0 otherwise.
  sbvBitCast x = SBV.ite x (SBV.literal 1) (SBV.literal 0)

-- | Bitcast from 'Bool' to a 1-bit unsigned bit-vector.
instance PEvalBitCastTerm Bool (WordN 1) where
  pevalBitCastTerm = pevalBitCastGeneral

  -- Solver-level encoding: select the literal 1 when the boolean holds and
  -- the literal 0 otherwise.
  sbvBitCast x = SBV.ite x (SBV.literal 1) (SBV.literal 0)

-- | Bitcast from a 1-bit signed bit-vector to 'Bool'.
instance PEvalBitCastTerm (IntN 1) Bool where
  pevalBitCastTerm = pevalBitCastGeneral

  -- Solver-level encoding: the result is the value of bit 0 of the
  -- bit-vector.
  sbvBitCast x = SBV.sTestBit x 0

-- | Bitcast from a 1-bit unsigned bit-vector to 'Bool'.
instance PEvalBitCastTerm (WordN 1) Bool where
  pevalBitCastTerm = pevalBitCastGeneral

  -- Solver-level encoding: the result is the value of bit 0 of the
  -- bit-vector.
  sbvBitCast x = SBV.sTestBit x 0

-- | Bitcast from an unsigned bit-vector of width @eb + sb@ to a
-- floating-point number with @eb@ exponent bits and @sb@ significand bits.
instance
  (n ~ (eb + sb), ValidFP eb sb, KnownNat n, 1 <= n) =>
  PEvalBitCastTerm (WordN n) (FP eb sb)
  where
  pevalBitCastTerm = pevalBitCastGeneral

  -- 'withValidFPProofs' supplies the type-level facts about @eb@ and @sb@
  -- that 'SBV.sWordAsSFloatingPoint' requires.
  sbvBitCast = withValidFPProofs @eb @sb $ SBV.sWordAsSFloatingPoint

-- | Bitcast from a signed bit-vector of width @eb + sb@ to a floating-point
-- number with @eb@ exponent bits and @sb@ significand bits.
instance
  (n ~ (eb + sb), ValidFP eb sb, KnownNat n, 1 <= n) =>
  PEvalBitCastTerm (IntN n) (FP eb sb)
  where
  pevalBitCastTerm = pevalBitCastGeneral

  -- The signed vector is first converted to an unsigned vector with
  -- 'SBV.sFromIntegral' and then passed to 'SBV.sWordAsSFloatingPoint'.
  -- 'withValidFPProofs' supplies the type-level facts the latter requires.
  sbvBitCast =
    withValidFPProofs @eb @sb $
      SBV.sWordAsSFloatingPoint . SBV.sFromIntegral

-- | Bitcast from a floating-point number to an unsigned bit-vector of width
-- @eb + sb@, with a caller-supplied default used for NaN inputs.
instance
  (n ~ (eb + sb), ValidFP eb sb, KnownNat n, 1 <= n) =>
  PEvalBitCastOrTerm (FP eb sb) (WordN n)
  where
  pevalBitCastOrTerm = pevalBitCastOr

  -- Solver-level encoding: when @v@ is NaN the default @d@ is selected,
  -- otherwise the result of 'SBV.sFloatingPointAsSWord' applied to @v@.
  sbvBitCastOr d v =
    withValidFPProofs @eb @sb $
      SBV.ite
        (SBV.fpIsNaN v)
        d
        (SBV.sFloatingPointAsSWord v)

-- | Bitcast from a floating-point number to a signed bit-vector of width
-- @eb + sb@, with a caller-supplied default used for NaN inputs.
instance
  (n ~ (eb + sb), ValidFP eb sb, KnownNat n, 1 <= n) =>
  PEvalBitCastOrTerm (FP eb sb) (IntN n)
  where
  pevalBitCastOrTerm = pevalBitCastOr

  -- Solver-level encoding: when @v@ is NaN the default @d@ is selected.
  -- Otherwise @v@ is turned into an unsigned vector with
  -- 'SBV.sFloatingPointAsSWord' and converted to the signed type with
  -- 'SBV.sFromIntegral'.
  sbvBitCastOr d v =
    withValidFPProofs @eb @sb $
      SBV.ite
        (SBV.fpIsNaN v)
        d
        (SBV.sFromIntegral $ SBV.sFloatingPointAsSWord v)