{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeFdiv
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeFdiv
  ( SafeFdiv (..),
    FdivOr (..),
    fdivOrZero,
    recipOrZero,
  )
where

import Control.Exception (ArithException (RatioZeroDenominator), throw)
import Control.Monad.Error.Class (MonadError (throwError))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.Core.Data.Class.SymEq (SymEq ((.==)))
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge, tryMerge)
import Grisette.Internal.SymPrim.AlgReal
  ( AlgReal (AlgExactRational),
    UnsupportedAlgRealOperation (UnsupportedAlgRealOperation),
  )
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( PEvalFractionalTerm (pevalFdivTerm, pevalRecipTerm),
  )
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Except
-- >>> import Control.Exception

-- | Fractional operations that return a caller-supplied fallback instead of
-- failing when the operation is undefined (e.g. a zero denominator).
--
-- In every method, the first argument is the fallback value.
class FdivOr a where
  -- | @'fdivOr' d x y@ computes @x '/' y@, or returns the fallback @d@ when
  -- the division cannot be performed.
  --
  -- >>> fdivOr "d" "a" "b" :: SymAlgReal
  -- (ite (= b 0.0) d (fdiv a b))
  fdivOr :: a -> a -> a -> a

  -- | @'recipOr' d x@ computes @'recip' x@, or returns the fallback @d@ when
  -- the reciprocal cannot be taken.
  --
  -- >>> recipOr "d" "a" :: SymAlgReal
  -- (ite (= a 0.0) d (recip a))
  recipOr :: a -> a -> a

-- | Safe '/' that returns zero when the division cannot be performed.
--
-- The zero fallback is computed from the numerator as @l - l@ rather than
-- written as a literal.
-- NOTE(review): presumably this lets the zero share @l@'s shape for types
-- whose literals need extra information to construct — confirm.
fdivOrZero :: (FdivOr a, Num a) => a -> a -> a
fdivOrZero l = fdivOr zero l
  where
    zero = l - l

-- | Safe 'recip' that returns zero when the reciprocal cannot be taken.
--
-- The zero fallback is computed from the argument as @v - v@ rather than
-- written as a literal.
recipOrZero :: (FdivOr a, Num a) => a -> a
recipOrZero v = recipOr zero v
  where
    zero = v - v

-- | Fractional division and reciprocal that report failure (a zero
-- denominator) through 'MonadError' instead of crashing. They are usable in
-- multi-path execution because the result monad supports 'TryMerge'.
class (MonadError e m, TryMerge m, Mergeable a) => SafeFdiv e a m where
  -- | Divide the first argument by the second, raising an error in @m@ when
  -- the denominator is zero.
  --
  -- >>> safeFdiv "a" "b" :: ExceptT ArithException Union SymAlgReal
  -- ExceptT {If (= b 0.0) (Left Ratio has zero denominator) (Right (fdiv a b))}
  safeFdiv :: a -> a -> m a

  -- | Reciprocal of the argument, raising an error in @m@ when it is zero.
  --
  -- By default this is 'safeFdiv' with @'fromRational' 1@ as the numerator.
  --
  -- >>> safeRecip "a" :: ExceptT ArithException Union SymAlgReal
  -- ExceptT {If (= a 0.0) (Left Ratio has zero denominator) (Right (recip a))}
  safeRecip :: a -> m a
  default safeRecip :: (Fractional a) => a -> m a
  safeRecip = safeFdiv $ fromRational 1
  {-# INLINE safeRecip #-}

  {-# MINIMAL safeFdiv #-}

instance FdivOr AlgReal where
  -- Exact division on the 'AlgExactRational' representation. A zero
  -- denominator selects the fallback @d@. Any other representation is
  -- rejected with 'UnsupportedAlgRealOperation'.
  fdivOr d lhs rhs =
    case (lhs, rhs) of
      (AlgExactRational l, AlgExactRational r)
        | r == 0 -> d
        | otherwise -> AlgExactRational (l / r)
      -- Throw the error because the user should never construct an AlgReal
      -- other than AlgExactRational.
      _ -> throw $ UnsupportedAlgRealOperation "fdivOr" details
    where
      details = show d <> " and " <> show lhs <> " and " <> show rhs
  {-# INLINE fdivOr #-}

  -- Exact reciprocal on the 'AlgExactRational' representation. Zero selects
  -- the fallback @d@. Other representations are rejected as in 'fdivOr'.
  recipOr d v =
    case v of
      AlgExactRational l
        | l == 0 -> d
        | otherwise -> AlgExactRational (recip l)
      _ -> throw $ UnsupportedAlgRealOperation "recipOr" details
    where
      details = show d <> " and " <> show v
  {-# INLINE recipOr #-}

instance
  ( MonadError ArithException m,
    TryMerge m
  ) =>
  SafeFdiv ArithException AlgReal m
  where
  -- Exact division on the 'AlgExactRational' representation. A zero
  -- denominator is reported with 'throwError' 'RatioZeroDenominator'
  -- (wrapped in 'tryMerge'). Any other representation is rejected by
  -- throwing 'UnsupportedAlgRealOperation'.
  safeFdiv (AlgExactRational l) (AlgExactRational r)
    | r /= 0 = pure $ AlgExactRational (l / r)
    | otherwise = tryMerge $ throwError RatioZeroDenominator
  safeFdiv l r =
    -- Throw the error because the user should never construct an AlgReal
    -- other than AlgExactRational.
    throw $
      UnsupportedAlgRealOperation "safeFdiv" $
        show l <> " and " <> show r
  {-# INLINE safeFdiv #-}

  -- Exact reciprocal on the 'AlgExactRational' representation. Zero is
  -- reported with 'throwError' 'RatioZeroDenominator' (wrapped in
  -- 'tryMerge').
  safeRecip (AlgExactRational l)
    | l /= 0 = pure $ AlgExactRational (recip l)
    | otherwise = tryMerge $ throwError RatioZeroDenominator
  safeRecip l =
    -- As in 'safeFdiv': the user should never construct an AlgReal other
    -- than AlgExactRational, so any other representation is an error.
    throw $ UnsupportedAlgRealOperation "safeRecip" $ show l
  -- Inlined like 'safeFdiv' (previously the only method here without the
  -- pragma).
  {-# INLINE safeRecip #-}

instance FdivOr SymAlgReal where
  -- Symbolic if-then-else: the fallback @d@ when the denominator equals zero,
  -- otherwise the division term built by 'pevalFdivTerm'.
  fdivOr d (SymAlgReal numerT) denom@(SymAlgReal denomT) =
    let denomIsZero = denom .== con 0
        quotient = SymAlgReal (pevalFdivTerm numerT denomT)
     in symIte denomIsZero d quotient

  -- Symbolic if-then-else: the fallback @d@ when the argument equals zero,
  -- otherwise the reciprocal term built by 'pevalRecipTerm'.
  recipOr d x@(SymAlgReal xT) =
    let xIsZero = x .== con 0
        inverse = SymAlgReal (pevalRecipTerm xT)
     in symIte xIsZero d inverse

instance
  (MonadError ArithException m, MonadUnion m) =>
  SafeFdiv ArithException SymAlgReal m
  where
  -- Branch with 'mrgIf' on whether the denominator equals zero. The zero
  -- branch raises 'RatioZeroDenominator'. The other branch returns the
  -- division term built by 'pevalFdivTerm'.
  safeFdiv (SymAlgReal numerT) denom@(SymAlgReal denomT) =
    mrgIf denomIsZero (throwError RatioZeroDenominator) (pure quotient)
    where
      denomIsZero = denom .== con 0
      quotient = SymAlgReal (pevalFdivTerm numerT denomT)

  -- Same shape as 'safeFdiv', with the argument itself as the denominator
  -- and the result built by 'pevalRecipTerm'.
  safeRecip x@(SymAlgReal xT) =
    mrgIf xIsZero (throwError RatioZeroDenominator) (pure inverse)
    where
      xIsZero = x .== con 0
      inverse = SymAlgReal (pevalRecipTerm xT)