{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeLogBase
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeLogBase
  ( SafeLogBase (..),
    LogBaseOr (..),
    logBaseOrZero,
  )
where

import Control.Exception (ArithException (RatioZeroDenominator))
import Control.Monad.Error.Class (MonadError (throwError))
import Grisette.Internal.Core.Control.Monad.Class.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.ITEOp (ITEOp (symIte))
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.SymEq (SymEq ((.==)))
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge)
import Grisette.Internal.SymPrim.Prim.Internal.Term
  ( FloatingUnaryOp (FloatingLog),
    PEvalFloatingTerm (pevalFloatingUnaryTerm),
    PEvalFractionalTerm (pevalFdivTerm),
  )
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Control.Monad.Except
-- >>> import Control.Exception

-- | Safe 'logBase' with default values returned on exception.
--
-- The arguments are, in order: the default value, the base, and the value
-- whose logarithm is taken. Which inputs count as exceptional is decided by
-- each instance.
class LogBaseOr a where
  -- | Safe 'logBase' with default values returned on exception.
  --
  -- @'logBaseOr' d base x@ computes @'logBase' base x@, but yields @d@ for
  -- the inputs the instance treats as exceptional. For 'SymAlgReal' that is
  -- a base equal to 1, as the example below shows.
  --
  -- >>> logBaseOr "d" "base" "val" :: SymAlgReal
  -- (ite (= base 1.0) d (fdiv (log val) (log base)))
  logBaseOr :: a -> a -> a -> a

-- | Safe 'logBase' with 0 returned on exception.
--
-- @'logBaseOrZero' base x@ computes @'logBase' base x@ via 'logBaseOr',
-- falling back to a zero value for the inputs 'logBaseOr' treats as
-- exceptional.
--
-- The zero default is built as @base - base@ from the argument itself rather
-- than from a numeric literal, so it is semantically zero but may appear as a
-- subtraction in a symbolic result. NOTE(review): presumably deliberate;
-- confirm before replacing it with @0@.
logBaseOrZero :: (LogBaseOr a, Num a) => a -> a -> a
logBaseOrZero base = logBaseOr (base - base) base
{-# INLINE logBaseOrZero #-}

-- | Safe 'logBase' with monadic error handling in multi-path execution.
--
-- Instead of returning a default value, an instance reports exceptional
-- inputs through the 'MonadError' interface of the result monad. The
-- 'SymAlgReal' instance in this module throws 'RatioZeroDenominator' when
-- the base is 1.
class (MonadError e m, TryMerge m, Mergeable a) => SafeLogBase e a m where
  -- | Safe 'logBase' with monadic error handling in multi-path execution.
  --
  -- @'safeLogBase' base x@ computes @'logBase' base x@ in @m@.
  --
  -- >>> safeLogBase (ssym "base") (ssym "val") :: ExceptT ArithException Union SymAlgReal
  -- ExceptT {If (= base 1.0) (Left Ratio has zero denominator) (Right (fdiv (log val) (log base)))}
  safeLogBase :: a -> a -> m a
  -- There is no meaningful generic implementation. Fail loudly with the
  -- method name instead of a bare 'undefined' if an instance omits it.
  safeLogBase =
    error
      "Grisette.Internal.Core.Data.Class.SafeLogBase.safeLogBase: \
      \no default implementation; the instance must define safeLogBase"
  {-# INLINE safeLogBase #-}

-- | Returns the default exactly when the base equals 1. Every other input is
-- encoded directly as the symbolic quotient @log x / log base@.
instance LogBaseOr SymAlgReal where
  logBaseOr :: SymAlgReal -> SymAlgReal -> SymAlgReal -> SymAlgReal
  logBaseOr d base@(SymAlgReal baseTerm) (SymAlgReal valTerm) =
    -- log 1 is 0, so a base of 1 would make the quotient below divide by
    -- zero; select the default on that branch instead.
    symIte (base .== 1) d $
      SymAlgReal $
        pevalFdivTerm
          (pevalFloatingUnaryTerm FloatingLog valTerm)
          (pevalFloatingUnaryTerm FloatingLog baseTerm)
  {-# INLINE logBaseOr #-}

-- | Throws 'RatioZeroDenominator' on the branch where the base equals 1.
-- Otherwise it returns the symbolic quotient @log x / log base@.
instance
  (MonadError ArithException m, MonadUnion m) =>
  SafeLogBase ArithException SymAlgReal m
  where
  safeLogBase :: SymAlgReal -> SymAlgReal -> m SymAlgReal
  safeLogBase base@(SymAlgReal baseTerm) (SymAlgReal valTerm) =
    -- log 1 is 0, so a base of 1 would make the quotient divide by zero;
    -- report that branch as an error and merge it with the success branch.
    mrgIf (base .== 1) (throwError RatioZeroDenominator) $
      pure $
        SymAlgReal $
          pevalFdivTerm
            (pevalFloatingUnaryTerm FloatingLog valTerm)
            (pevalFloatingUnaryTerm FloatingLog baseTerm)
  {-# INLINE safeLogBase #-}