{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Eta reduce" #-}

-- |
-- Module      :   Grisette.Internal.Unified.Class.UnifiedSafeFromFP
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Unified.Class.UnifiedSafeFromFP
  ( UnifiedSafeFromFP (..),
    safeFromFP,
  )
where

import Control.Monad.Error.Class (MonadError)
import GHC.TypeNats (KnownNat, type (<=))
import Grisette.Internal.Core.Data.Class.SafeFromFP (SafeFromFP)
import qualified Grisette.Internal.Core.Data.Class.SafeFromFP as SafeFromFP
import Grisette.Internal.SymPrim.AlgReal (AlgReal)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.FP
  ( FP,
    FPRoundingMode,
    NotRepresentableFPError,
    ValidFP,
  )
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal)
import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN)
import Grisette.Internal.SymPrim.SymFP (SymFP, SymFPRoundingMode)
import Grisette.Internal.SymPrim.SymInteger (SymInteger)
import Grisette.Internal.Unified.Class.UnifiedSimpleMergeable
  ( UnifiedBranching (withBaseBranching),
  )
import Grisette.Internal.Unified.EvalModeTag (EvalModeTag (S))
import Grisette.Internal.Unified.Util (withMode)

-- | Unified `Grisette.Internal.Core.Data.Class.SafeFromFP.safeFromFP`
-- operation.
--
-- This function isn't able to infer the mode, so you need to provide the mode
-- explicitly with a type application. For example:
--
-- > safeFromFP @mode rd fp
--
-- where @rd@ is the rounding mode and @fp@ is the floating point value to be
-- converted. The result is produced in a monad @m@ that must be a
-- @MonadError e@.
safeFromFP ::
  forall mode e a fp fprd m.
  (UnifiedSafeFromFP mode e a fp fprd m, MonadError e m) =>
  fprd ->
  fp ->
  m a
safeFromFP rd fp =
  -- The 'UnifiedSafeFromFP' instance hands us the base 'SafeFromFP'
  -- constraint needed by the underlying operation.
  withBaseSafeFromFP @mode @e @a @fp @fprd @m $
    SafeFromFP.safeFromFP rd fp

-- | Unified safe conversion from floating point values.
--
-- The purpose of this class is to help resolve `SafeFromFP` constraints: an
-- instance of it can supply the underlying `SafeFromFP` dictionary to a
-- computation through 'withBaseSafeFromFP'.
class
  UnifiedSafeFromFP (mode :: EvalModeTag) e a fp fprd m
  where
  -- | Run a computation that requires the base 'SafeFromFP' instance.
  withBaseSafeFromFP :: (SafeFromFP e a fp fprd m => res) -> res

-- Fallback instance: when a 'SafeFromFP' instance for the requested types is
-- already in scope, the constraint is simply passed through unchanged.
instance
  {-# INCOHERENT #-}
  (UnifiedBranching mode m, SafeFromFP e a fp fprd m) =>
  UnifiedSafeFromFP mode e a fp fprd m
  where
  withBaseSafeFromFP r = r

-- Conversion from a concrete 'FP' to 'Integer', usable in either evaluation
-- mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching mode m,
    ValidFP eb sb
  ) =>
  UnifiedSafeFromFP
    mode
    NotRepresentableFPError
    Integer
    (FP eb sb)
    FPRoundingMode
    m
  where
  -- 'withBaseBranching' provides a type-level
  -- @If (IsConMode mode) (TryMerge m) (SymBranching m)@ constraint that only
  -- reduces once @mode@ is fixed. 'withMode' brings @mode ~ 'C@ or
  -- @mode ~ 'S@ into scope for each arm, which is presumably what lets the
  -- base 'SafeFromFP' instance be resolved. The two arms are textually
  -- identical but are type-checked under different equalities.
  withBaseSafeFromFP r =
    withMode @mode (withBaseBranching @mode @m r) (withBaseBranching @mode @m r)

-- Conversion from a concrete 'FP' to 'AlgReal', usable in either evaluation
-- mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching mode m,
    ValidFP eb sb
  ) =>
  UnifiedSafeFromFP
    mode
    NotRepresentableFPError
    AlgReal
    (FP eb sb)
    FPRoundingMode
    m
  where
  -- 'withBaseBranching' provides a type-level
  -- @If (IsConMode mode) (TryMerge m) (SymBranching m)@ constraint that only
  -- reduces once @mode@ is fixed. 'withMode' brings @mode ~ 'C@ or
  -- @mode ~ 'S@ into scope for each arm, which is presumably what lets the
  -- base 'SafeFromFP' instance be resolved. The two arms are textually
  -- identical but are type-checked under different equalities.
  withBaseSafeFromFP r =
    withMode @mode (withBaseBranching @mode @m r) (withBaseBranching @mode @m r)

-- Conversion from a concrete 'FP' to a signed bit-vector 'IntN', usable in
-- either evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching mode m,
    ValidFP eb sb,
    KnownNat n,
    1 <= n
  ) =>
  UnifiedSafeFromFP
    mode
    NotRepresentableFPError
    (IntN n)
    (FP eb sb)
    FPRoundingMode
    m
  where
  -- 'withBaseBranching' provides a type-level
  -- @If (IsConMode mode) (TryMerge m) (SymBranching m)@ constraint that only
  -- reduces once @mode@ is fixed. 'withMode' brings @mode ~ 'C@ or
  -- @mode ~ 'S@ into scope for each arm, which is presumably what lets the
  -- base 'SafeFromFP' instance be resolved. The two arms are textually
  -- identical but are type-checked under different equalities.
  withBaseSafeFromFP r =
    withMode @mode (withBaseBranching @mode @m r) (withBaseBranching @mode @m r)

-- Conversion from a concrete 'FP' to an unsigned bit-vector 'WordN', usable
-- in either evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching mode m,
    ValidFP eb sb,
    KnownNat n,
    1 <= n
  ) =>
  UnifiedSafeFromFP
    mode
    NotRepresentableFPError
    (WordN n)
    (FP eb sb)
    FPRoundingMode
    m
  where
  -- 'withBaseBranching' provides a type-level
  -- @If (IsConMode mode) (TryMerge m) (SymBranching m)@ constraint that only
  -- reduces once @mode@ is fixed. 'withMode' brings @mode ~ 'C@ or
  -- @mode ~ 'S@ into scope for each arm, which is presumably what lets the
  -- base 'SafeFromFP' instance be resolved. The two arms are textually
  -- identical but are type-checked under different equalities.
  withBaseSafeFromFP r =
    withMode @mode (withBaseBranching @mode @m r) (withBaseBranching @mode @m r)

-- Conversion from a symbolic 'SymFP' to 'SymInteger'; only available in the
-- symbolic ('S) evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching 'S m,
    ValidFP eb sb
  ) =>
  UnifiedSafeFromFP
    'S
    NotRepresentableFPError
    SymInteger
    (SymFP eb sb)
    SymFPRoundingMode
    m
  where
  -- The mode is fixed to 'S in the instance head, so no 'withMode' split is
  -- needed; 'withBaseBranching' can be used directly.
  withBaseSafeFromFP r = withBaseBranching @'S @m r

-- Conversion from a symbolic 'SymFP' to 'SymAlgReal'; only available in the
-- symbolic ('S) evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching 'S m,
    ValidFP eb sb
  ) =>
  UnifiedSafeFromFP
    'S
    NotRepresentableFPError
    SymAlgReal
    (SymFP eb sb)
    SymFPRoundingMode
    m
  where
  -- The mode is fixed to 'S in the instance head, so no 'withMode' split is
  -- needed; 'withBaseBranching' can be used directly.
  withBaseSafeFromFP r = withBaseBranching @'S @m r

-- Conversion from a symbolic 'SymFP' to a signed symbolic bit-vector
-- 'SymIntN'; only available in the symbolic ('S) evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching 'S m,
    ValidFP eb sb,
    KnownNat n,
    1 <= n
  ) =>
  UnifiedSafeFromFP
    'S
    NotRepresentableFPError
    (SymIntN n)
    (SymFP eb sb)
    SymFPRoundingMode
    m
  where
  -- The mode is fixed to 'S in the instance head, so no 'withMode' split is
  -- needed; 'withBaseBranching' can be used directly.
  withBaseSafeFromFP r = withBaseBranching @'S @m r

-- Conversion from a symbolic 'SymFP' to an unsigned symbolic bit-vector
-- 'SymWordN'; only available in the symbolic ('S) evaluation mode.
instance
  ( MonadError NotRepresentableFPError m,
    UnifiedBranching 'S m,
    ValidFP eb sb,
    KnownNat n,
    1 <= n
  ) =>
  UnifiedSafeFromFP
    'S
    NotRepresentableFPError
    (SymWordN n)
    (SymFP eb sb)
    SymFPRoundingMode
    m
  where
  -- The mode is fixed to 'S in the instance head, so no 'withMode' split is
  -- needed; 'withBaseBranching' can be used directly.
  withBaseSafeFromFP r = withBaseBranching @'S @m r