{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeBitCast
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeBitCast
  ( SafeBitCast (..),
  )
where

import Control.Monad.Error.Class (MonadError (throwError))
import Data.Int (Int16, Int32, Int64)
import Data.SBV (Word32)
import Data.Word (Word16, Word64)
import GHC.TypeLits (KnownNat, type (+), type (<=))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.BitCast
  ( BitCast (bitCast),
    BitCastOr,
    bitCastOrCanonical,
  )
import Grisette.Internal.Core.Data.Class.IEEEFP (fpIsNaN)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.SymIEEEFP
  ( SymIEEEFPTraits (symFpIsNaN),
  )
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.SymPrim.BV (IntN, WordN, WordN16, WordN32, WordN64)
import Grisette.Internal.SymPrim.FP
  ( FP,
    FP16,
    FP32,
    FP64,
    NotRepresentableFPError (NaNError),
    ValidFP,
  )
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymFP (SymFP)

-- | Bit-casting that can fail.
--
-- @safeBitCast x@ converts @x :: a@ into a value of type @b@ inside the
-- monad @m@. If @x@ cannot be bit-cast precisely, an error of type @e@ is
-- thrown through @MonadError@. The floating-point instances in this module
-- throw @NaNError@ for NaN inputs.
class
  ( MonadError e m,
    TryMerge m,
    Mergeable b,
    BitCastOr a b
  ) =>
  SafeBitCast e a b m
  where
  -- | Cast the argument, throwing in @m@ when it has no precise bit-level
  -- counterpart in the target type.
  safeBitCast :: a -> m b

-- | Safe bit-cast from a concrete floating-point value @FP eb sb@ to a
-- @WordN r@ bit vector, where @r = eb + sb@.
--
-- A NaN input throws @NaNError@. Every other input is converted with
-- @bitCastOrCanonical@. Both results are passed through @tryMerge@.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    TryMerge m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (FP eb sb) (WordN r) m
  where
  safeBitCast a
    -- NaN is reported as an error instead of being cast to a bit pattern.
    | fpIsNaN a = tryMerge $ throwError NaNError
    | otherwise = tryMerge $ return $ bitCastOrCanonical a

-- | Safe bit-cast from a concrete floating-point value @FP eb sb@ to an
-- @IntN r@ bit vector, where @r = eb + sb@.
--
-- A NaN input throws @NaNError@. Every other input is converted with
-- @bitCastOrCanonical@. Both results are passed through @tryMerge@.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    TryMerge m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (FP eb sb) (IntN r) m
  where
  safeBitCast a
    -- NaN is reported as an error instead of being cast to a bit pattern.
    | fpIsNaN a = tryMerge $ throwError NaNError
    | otherwise = tryMerge $ return $ bitCastOrCanonical a

-- SAFE_BIT_CAST_VIA_INTERMEDIATE(from, to, intermediate) expands to a
-- SafeBitCast instance from type @from@ to type @to@. It first runs
-- @safeBitCast@ from @from@ to the bit-vector type @intermediate@, so any
-- error thrown by that instance propagates unchanged. It then converts the
-- intermediate result to @to@ with the pure @bitCast@ and wraps it in
-- @tryMerge@.
-- No comments can go inside the macro body: its lines are joined by the
-- trailing backslashes, which is also why the do-block uses an explicit @;@.
#define SAFE_BIT_CAST_VIA_INTERMEDIATE(from, to, intermediate) \
instance \
  (MonadError NotRepresentableFPError m, TryMerge m) => \
  SafeBitCast NotRepresentableFPError from to m \
  where \
  safeBitCast a = do \
    r :: intermediate <- safeBitCast a; \
    tryMerge $ return $ bitCast r

-- Instances for the fixed-width concrete floating-point types. Each one goes
-- through the WordN bit vector of the same width (64, 32 or 16 bits).
-- NOTE(review): the @#if 1@ guard below is always true. It looks like a
-- leftover toggle; confirm whether it can be removed.
#if 1
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Word64, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Int64, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP64, Double, WordN64)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Word32, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Int32, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP32, Float, WordN32)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP16, Word16, WordN16)
SAFE_BIT_CAST_VIA_INTERMEDIATE(FP16, Int16, WordN16)
#endif

-- | Safe bit-cast from a @SymFP eb sb@ value to a @SymWordN r@ bit vector,
-- where @r = eb + sb@.
--
-- The NaN test is the SymBool returned by @symFpIsNaN@, so the two outcomes
-- are combined with @mrgIf@:
--
-- * when @symFpIsNaN a@ holds, the branch throws @NaNError@;
-- * otherwise the branch returns @bitCastOrCanonical a@.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    MonadUnion m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (SymFP eb sb) (SymWordN r) m
  where
  safeBitCast a =
    mrgIf
      (symFpIsNaN a)
      (throwError NaNError)
      (return $ bitCastOrCanonical a)

-- | Safe bit-cast from a @SymFP eb sb@ value to a @SymIntN r@ bit vector,
-- where @r = eb + sb@.
--
-- The NaN test is the SymBool returned by @symFpIsNaN@, so the two outcomes
-- are combined with @mrgIf@:
--
-- * when @symFpIsNaN a@ holds, the branch throws @NaNError@;
-- * otherwise the branch returns @bitCastOrCanonical a@.
instance
  ( ValidFP eb sb,
    r ~ (eb + sb),
    KnownNat r,
    1 <= r,
    MonadUnion m,
    MonadError NotRepresentableFPError m
  ) =>
  SafeBitCast NotRepresentableFPError (SymFP eb sb) (SymIntN r) m
  where
  safeBitCast a =
    mrgIf
      (symFpIsNaN a)
      (throwError NaNError)
      (return $ bitCastOrCanonical a)