{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Internal.Core.Data.Class.SafeBitCast
( SafeBitCast (..),
)
where
import Control.Monad.Error.Class (MonadError (throwError))
import Data.Int (Int16, Int32, Int64)
import Data.SBV (Word32)
import Data.Word (Word16, Word64)
import GHC.TypeLits (KnownNat, type (+), type (<=))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.BitCast
( BitCast (bitCast),
BitCastOr,
bitCastOrCanonical,
)
import Grisette.Internal.Core.Data.Class.IEEEFP (fpIsNaN)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.SymIEEEFP
( SymIEEEFPTraits (symFpIsNaN),
)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.SymPrim.BV (IntN, WordN, WordN16, WordN32, WordN64)
import Grisette.Internal.SymPrim.FP
( FP,
FP16,
FP32,
FP64,
NotRepresentableFPError (NaNError),
ValidFP,
)
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymFP (SymFP)
-- | A bit-cast from @a@ to @b@ that runs in a monad @m@ able to report
-- failure through 'MonadError' with error type @e@.
--
-- Every instance must also have a 'BitCastOr' conversion between the same
-- two types, and the target type must be 'Mergeable'. The instances in this
-- module use @e = NotRepresentableFPError@ and raise 'NaNError' when the
-- floating-point input is NaN.
class
  ( MonadError e m
  , TryMerge m
  , Mergeable b
  , BitCastOr a b
  ) =>
  SafeBitCast e a b m
  where
  -- | Convert the bits of the input, or fail in @m@ with an error of type @e@.
  safeBitCast :: a -> m b
-- | Safe bit-cast from a concrete floating-point value to an unsigned
-- bit-vector whose width @r@ equals the total FP width @eb + sb@.
--
-- A NaN input is rejected with 'NaNError'. Any other input is converted with
-- 'bitCastOrCanonical'. Both branches are wrapped in 'tryMerge'.
-- NOTE(review): NaN is presumably rejected because IEEE-754 NaN has several
-- encodings, so no single bit pattern is the faithful answer. Confirm.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    TryMerge m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (FP eb sb) (WordN r) m
  where
  safeBitCast a
    | fpIsNaN a = tryMerge $ throwError NaNError
    | otherwise = tryMerge $ return $ bitCastOrCanonical a
-- | Safe bit-cast from a concrete floating-point value to a signed
-- bit-vector whose width @r@ equals the total FP width @eb + sb@.
--
-- A NaN input is rejected with 'NaNError'. Any other input is converted with
-- 'bitCastOrCanonical'. Both branches are wrapped in 'tryMerge'.
-- This mirrors the unsigned 'WordN' instance except for the target type.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    TryMerge m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (FP eb sb) (IntN r) m
  where
  safeBitCast a
    | fpIsNaN a = tryMerge $ throwError NaNError
    | otherwise = tryMerge $ return $ bitCastOrCanonical a
-- Generates a SafeBitCast instance from the FP type @from@ to the concrete
-- type @to@. It first performs the safe cast into the bit-vector type
-- @intermediate@ (which carries the NaN check), then converts that result to
-- @to@ with 'bitCast', wrapping the final value in 'tryMerge'.
-- The macro expands onto a single line, so the body is written as one
-- expression rather than a layout-based do block.
#define SAFE_BIT_CAST_VIA_INTERMEDIATE(from, to, intermediate) \
instance \
  (MonadError NotRepresentableFPError m, TryMerge m) => \
  SafeBitCast NotRepresentableFPError from to m \
  where \
    safeBitCast a = \
      (safeBitCast a :: m intermediate) >>= tryMerge . return . bitCast

SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Word64, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Int64, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Double, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Word32, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Int32, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Float, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP16, Word16, WordN16)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP16, Int16, WordN16)
-- | Safe bit-cast from a symbolic floating-point value to a symbolic
-- unsigned bit-vector whose width @r@ equals the total FP width @eb + sb@.
--
-- NaN-ness is a symbolic condition here ('symFpIsNaN'), so the two outcomes
-- are combined with 'mrgIf'. When the condition holds, the result throws
-- 'NaNError'. Otherwise it returns 'bitCastOrCanonical' of the input.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    MonadUnion m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (SymFP eb sb) (SymWordN r) m
  where
  safeBitCast a =
    mrgIf
      (symFpIsNaN a)
      (throwError NaNError)
      (return $ bitCastOrCanonical a)
-- | Safe bit-cast from a symbolic floating-point value to a symbolic signed
-- bit-vector whose width @r@ equals the total FP width @eb + sb@.
--
-- NaN-ness is a symbolic condition here ('symFpIsNaN'), so the two outcomes
-- are combined with 'mrgIf'. When the condition holds, the result throws
-- 'NaNError'. Otherwise it returns 'bitCastOrCanonical' of the input.
-- This mirrors the 'SymWordN' instance except for the target type.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    MonadUnion m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (SymFP eb sb) (SymIntN r) m
  where
  safeBitCast a =
    mrgIf
      (symFpIsNaN a)
      (throwError NaNError)
      (return $ bitCastOrCanonical a)